diff --git a/Manifest.toml b/Manifest.toml index a761c30a3..c9547fc45 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -53,6 +53,11 @@ version = "1.0.1" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[ExprTools]] +git-tree-sha1 = "6f0517056812fd6aa3af23d4b70d5325a2ae4e95" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.1" + [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] git-tree-sha1 = "51cc2f9bc4eb9c6c0e81ec2f779d1085583cc956" @@ -157,6 +162,12 @@ git-tree-sha1 = "e19b98acb182567bcb7b75bb5d9eedf3a3b5ec6c" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "0.10.0" +[[SpecializeVarargs]] +deps = ["ExprTools"] +git-tree-sha1 = "198d9939074e645b816c3c7c857946b258e7cc43" +uuid = "24973c7f-061f-47f0-b8d1-653b711ffc2d" +version = "0.1.1" + [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c" diff --git a/Project.toml b/Project.toml index dc53befbc..cea906cb4 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +SpecializeVarargs = "24973c7f-061f-47f0-b8d1-653b711ffc2d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/src/Zygote.jl b/src/Zygote.jl index 11b6fe134..2871db480 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -10,6 +10,8 @@ using IRTools using MacroTools, Requires using MacroTools: @forward +using SpecializeVarargs + export Params, gradient, pullback, @code_grad include("tools/idset.jl") diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 4c8a76179..6e7935cb0 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -26,12 +26,12 @@ end # Wrappers -_pullback(f, args...) = _pullback(Context(), f, args...) +@specialize_vararg 5 _pullback(f, args...) = _pullback(Context(), f, args...) tailmemaybe(::Nothing) = nothing tailmemaybe(x::Tuple) = Base.tail(x) -function pullback(f, args...) +@specialize_vararg 5 function pullback(f, args...) y, back = _pullback(f, args...) y, Δ -> tailmemaybe(back(Δ)) end @@ -40,7 +40,7 @@ sensitivity(y::Number) = one(y) sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.") sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))") -function gradient(f, args...) +@specialize_vararg 5 function gradient(f, args...) y, back = pullback(f, args...) return back(sensitivity(y)) end diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index b5b4e0816..27af74dda 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -16,6 +16,20 @@ ignore(T) = all(T -> T <: Type, T.parameters) return update!(meta.code, forw) end +@generated function _pullback(ctx::AContext, f, x) + #Core.println("p2: ", x) + T = Tuple{f,x} + ignore(T) && return :(f(x), Pullback{$T}(())) + g = try _lookup_grad(T) catch e e end + !(g isa Tuple) && return :(f(x), Pullback{$T}((f,))) + meta, forw, _ = g + argnames!(meta, Symbol("#self#"), :ctx, :f, :x) + # IRTools.verify(forw) + forw = slots!(pis!(inlineable!(forw))) + return update!(meta.code, forw) +end + + @generated function (j::Pullback{T})(Δ) where T ignore(T) && return :nothing g = try _lookup_grad(T)