From e2fcd614478199d7839d11cbf70b36bb6335ae95 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Tue, 20 Aug 2019 12:48:40 +0100 Subject: [PATCH] gradients of channels --- src/compiler/reverse.jl | 2 +- src/lib/base.jl | 20 ++++++++++++++++++++ test/runtests.jl | 4 ++++ test/structures.jl | 8 ++++++++ 4 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 test/structures.jl diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index e3bb16b11..c4891189e 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -84,7 +84,7 @@ function instrument(ir::IR) pr = Pipe(ir) for (v, st) in pr ex = st.expr - isexpr(ex, :foreigncall) && continue + isexpr(ex, :foreigncall, :isdefined) && continue isexpr(ex, :enter, :leave) && error("try/catch is not supported.") ex = instrument_new!(pr, v, ex) ex = instrument_literals!(pr, v, ex) diff --git a/src/lib/base.jl b/src/lib/base.jl index 91c2fca5d..2732cbaf9 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -46,3 +46,23 @@ end (nothing, Δ, nothing) end end + +# Channels + +@nograd Channel + +grad_mut(ch::Channel) = Channel(ch.sz_max) + +@adjoint! function put!(ch::Channel, x) + put!(ch, x), function (ȳ) + x̄ = take!(grad_mut(__context__, ch)) + (nothing, accum(x̄, ȳ), nothing) + end +end + +@adjoint! function take!(ch::Channel) + take!(ch), function (x̄) + put!(grad_mut(__context__, ch), x̄) + return + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 0142a6467..b14ed2fdd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,10 @@ end include("features.jl") end +@testset "Data Structures" begin + include("structures.jl") +end + @testset "Gradients" begin include("gradcheck.jl") end diff --git a/test/structures.jl b/test/structures.jl new file mode 100644 index 000000000..6e04c68ce --- /dev/null +++ b/test/structures.jl @@ -0,0 +1,8 @@ +using Zygote, Test + +function f(x) + ch = Channel(Inf) + put!(ch, x^2) + take!(ch) +end + +@test gradient(f, 5) == (20,)