Skip to content

Commit

Permalink
Specialize logdensityof for DensityMeasure
Browse files Browse the repository at this point in the history
Ensures proper type propagation (until future refactor of density calculation
engine).
  • Loading branch information
oschulz committed Nov 4, 2024
1 parent 31dd7c2 commit 113196a
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 7 deletions.
19 changes: 19 additions & 0 deletions ext/MeasureBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ using ChainRulesCore: NoTangent, ZeroTangent
import ChainRulesCore


# = utils ====================================================================

using MeasureBase: isneginf, isposinf

_isneginf_pullback(::Any) = (NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(isneginf), x) = isneginf(x), _logdensityof_rt_pullback

_isposinf_pullback(::Any) = (NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(isposinf), x) = isposinf(x), _isposinf_pullback


# = insupport & friends ======================================================

using MeasureBase:
Expand Down Expand Up @@ -44,4 +55,12 @@ _check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback


# = return type inference ====================================================

using MeasureBase: logdensityof_rt

_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v) = logdensityof_rt(target, v), _logdensityof_rt_pullback


end # module MeasureBaseChainRulesCoreExt
4 changes: 4 additions & 0 deletions src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ To compute a log-density relative to a specific base-measure, see
_checksupport(insupport(μ, x), result)
end

@inline function logdensityof_rt(::T, ::U) where {T,U}
Core.Compiler.return_type(logdensityof, Tuple{T,U})
end

_checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))


Expand Down
19 changes: 19 additions & 0 deletions src/density.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,25 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x)

density_def::DensityMeasure, x) = densityof.f, x)

function logdensityof::DensityMeasure, x::Any)
integrand, μ_base = μ.f, μ.base

base_logval = logdensityof(μ_base, x)

T = typeof(base_logval)
U = logdensityof_rt(integrand, x)
R = promote_type(T, U)

# Don't evaluate base measure if integrand is zero or NaN
if isneginf(base_logval)
R(-Inf)
else
integrand_logval = logdensityof(integrand, x)
convert(R, integrand_logval + base_logval)::R
end
end


"""
rebase(μ, ν)
Expand Down
12 changes: 5 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,18 @@ using InverseFunctions: FunctionWithInverse
unwrap(f) = f
unwrap(f::FunctionWithInverse) = f.f


fcomp(f, g) = fchain(g, f)
fcomp(::typeof(identity), g) = g
fcomp(f, ::typeof(identity)) = f
fcomp(::typeof(identity), ::typeof(identity)) = identity

near_neg_inf(::Type{T}) where {T<:Real} = T(-1E38) # Still fits into Float32

near_neg_inf(::Type{T}) where T<:Real = T(-1E38) # Still fits into Float32

isneginf(x) = isinf(x) && x < 0
isposinf(x) = isinf(x) && x > 0
isneginf(x) = isinf(x) && x < zero(x)
isposinf(x) = isinf(x) && x > zero(x)

isapproxzero(x::T) where T<:Real = x zero(T)
isapproxzero(x::T) where {T<:Real} = x zero(T)
isapproxzero(A::AbstractArray) = all(isapproxzero, A)

isapproxone(x::T) where T<:Real = x one(T)
isapproxone(x::T) where {T<:Real} = x one(T)
isapproxone(A::AbstractArray) = all(isapproxone, A)
16 changes: 16 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,22 @@ end
end
end

@testset "logdensityof" begin
f1 = let A=randn(Float32, 3,3); x -> sum(A*x); end
f2 = x -> sqrt(abs(sum(x)))
f3 = x -> 2 * sum(x)
f4 = x -> sum(sqrt.(abs.(x)))
m = @inferred ∫exp(f1, ∫exp(f2, ∫exp(f3, ∫exp(f4, StdUniform()^3))))

for x in [
Float32[0.7, 0.2, 0.5],
Float32[-0.7, 0.2, 0.5],
]
@test @inferred(logdensityof(m, x)) isa Float32
@test logdensityof(m, x) f1(x) + f2(x) + f3(x) + f4(x) + logdensityof(StdUniform()^3, x)
end
end

@testset "logdensity_rel" begin
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 0.0) == Inf
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 1.0) == -Inf
Expand Down

0 comments on commit 113196a

Please sign in to comment.