Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Jan 17, 2025
1 parent 9bf08fb commit 8baa2b2
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 9 deletions.
11 changes: 5 additions & 6 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from
# ZygoteRules here?
# function unthunk_tangent end
#
# `unthunk_tangent` methods, one per tangent container kind.
# NOTE(review): the unqualified definitions and the `ZygoteRules.`-qualified definitions
# below have identical bodies; this looks like the before/after sides of a diff shown
# together -- confirm which set actually belongs in the file.

# A thunk is forced with `unthunk`, and the result is passed through `wrap_chainrules_output`.
@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
# Tuples whose elements are all numbers are returned unchanged.
@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
# Arrays with a numeric element type are returned unchanged.
@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
# Any other array is processed element by element.
@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
# An `IdDict` is rebuilt with `unthunk_tangent` applied to every key and every value.
unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
# Same methods as above, attached to the function through the `ZygoteRules` module.
@inline ZygoteRules.unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
@inline ZygoteRules.unthunk_tangent(x::NTuple{N,<:Number}) where N = x
@inline ZygoteRules.unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
@inline ZygoteRules.unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
ZygoteRules.unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
# The `IdDict` method is declared non-differentiable via `@non_differentiable`.
@non_differentiable unthunk_tangent(::IdDict)


Expand Down
6 changes: 3 additions & 3 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ julia> gradient([7, 11], 0, 1) do x, y, d
"""
function gradient(f, args...)
    # `pullback` yields the primal value `y` and the reverse-pass closure `back`.
    y, back = pullback(f, args...)
    # Seed the reverse pass with `sensitivity(y)` and keep its raw result;
    # `unthunk_tangent` is applied once, right before projection onto `args`.
    grad = back(sensitivity(y))
    return _project_all(args, unthunk_tangent(grad))
end

# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
Expand Down Expand Up @@ -218,7 +218,7 @@ function withgradient(f, args...)
else
back(sensitivity(y))
end
results = _project_all(args, grad)
results = _project_all(args, unthunk_tangent(grad))
(val=y, grad=results)
end

Expand Down
28 changes: 28 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,31 @@ end
@test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]]
@test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible
end

@testset "Lazy" begin
    # `custom_add` has a hand-written rrule whose tangent for `y` is a thunk that
    # throws when forced. Only `x` is differentiated, so taking the gradient must
    # succeed without ever evaluating that thunk.
    custom_add(x, y) = x + y
    function ChainRulesCore.rrule(::typeof(custom_add), x, y)
        function pullback(Δ)
            return NoTangent(), unthunk(Δ), @thunk(error("Should not compute."))
        end
        return custom_add(x, y), pullback
    end

    x, y = 1f0, 1f0
    g = Zygote.gradient(x) do x
        sum(custom_add(x, y))
    end
    # d/dx (x + y) == 1; asserting the value also makes the testset record a result
    # instead of merely not throwing.
    @test g[1] == 1f0
end

@testset "No thunks in the gradient" begin
    # Minimal callable layer: a single weight matrix applied by multiplication.
    struct Dense
        w::Matrix{Float32}
    end
    (layer::Dense)(input) = layer.w * input

    x = ones(Float32, 3)
    layers = [Dense(rand(Float32, 3, 3))]

    # Differentiate with respect to the vector of layers; the per-layer tangent
    # must come back as a plain NamedTuple holding a materialized array.
    (g,) = gradient(ls -> sum(ls[1](x)), layers)
    layer_grad = first(g)
    @test layer_grad isa NamedTuple
    @test layer_grad.w isa Array
end

0 comments on commit 8baa2b2

Please sign in to comment.