Skip to content

Commit

Permalink
STASH bind
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Nov 7, 2023
1 parent e3c23c0 commit 80913b1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/MeasureBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import Base.iterate
import ConstructionBase
using ConstructionBase: constructorof
using IntervalSets
using OneTwoMany: secondarg
using OneTwoMany: firstarg, secondarg

using PrettyPrinting
const Pretty = PrettyPrinting
Expand Down
39 changes: 21 additions & 18 deletions src/combinators/bind.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@ See also [`mbind`](@ref).
function mkernel end
export mkernel

@inline mkernel(f_β::MKernel) = f_β
@inline mkernel(f_β, f_c = secondarg) = _generic_mkernel_impl(f_β, f_c)

@inline _generic_mkernel_impl(f_β, f_c) = MKernel(f_β, f_c)
@inline _generic_mkernel_impl(f_β::MKernel, ::typeof(secondarg)) = f_β


"""
struct MeasureBase.MKernel <: Function
Expand All @@ -45,12 +39,20 @@ Represents a generalized monatic transition kernel.
User code should not create instances of `MKernel` directly, but should call
[`mkernel`](@ref) instead.
"""
struct MKernel
f_β::FK
struct MKernel{FT,FC} <: Function
f_β::FT
f_c::FC
end


@inline mkernel(f_β::MKernel) = f_β
@inline mkernel(f_β, f_c = secondarg) = _generic_mkernel_impl(f_β, f_c)

@inline _generic_mkernel_impl(f_β, f_c) = MKernel(f_β, f_c)
@inline _generic_mkernel_impl(f_β::MKernel, ::typeof(secondarg)) = f_β



@doc raw"""
mbind(f_β, α::AbstractMeasure, f_c = OneTwoMany.secondarg)
mbind(f_β::MeasureBase.MKernel, α::AbstractMeasure)
Expand Down Expand Up @@ -102,7 +104,7 @@ The measure `α` that went into the bind can be retrieved via
Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)`
can be unambiguously split into `a` and `b` again, knowing `α`. This is
currently implemented for `f_c` that is either tuple or `=>`/`Pair` (these
currently implemented for `f_c` that is either `tuple` or `=>`/`Pair` (these
work for any combination of variate types), `vcat` (for tuple- or vector-like
variates) and `merge` (`NamedTuple` variates).
[`MeasureBase.split_point(::typeof(f_c), α)`](@ref) can be specialized to
Expand Down Expand Up @@ -152,19 +154,20 @@ export mbind

@inline mbind(f_β) = Base.Fix1(mbind, f_β)

@inline mbind(f_k::MKernel, α::AbstractMeasure) = mbind(f_k.f_β, α, f_k.f_c)

#@inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) --- temporary ---
@inline mbind(f_β, α::AbstractMeasure, f_c = secondarg) = _generic_mbind_impl(f_β, α, f_c)
@inline mbind(f_β, α::AbstractMeasure, f_c = secondarg) = _generic_mbind_impl(f_β, asmeasure(α), f_c)

@inline function _generic_mbind_impl(f_β, α::AbstractMeasure, f_c)
F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c)
Bind{F,M,G}(f_β, α, f_c)
end

function _generic_mbind_impl(f_β, α::Dirac, f_c)
mcombine(f_c, α, f_β.x))
end
@inline _generic_mbind_impl(f_β, α::Dirac, f_c) = mcombine(f_c, α, f_β.x))

@inline _generic_mbind_impl(@nospecialize(f_β), α::AbstractMeasure, ::typeof(firstarg)) = α
@inline _generic_mbind_impl(@nospecialize(f_β), α::Dirac, ::typeof(firstarg)) = α

@inline _generic_mbind_impl(f_k::MKernel, α::AbstractMeasure, ::typeof(secondarg)) = mbind(f_k.f_β, α, f_k.f_c)
@inline _generic_mbind_impl(f_k::MKernel, α::Dirac, ::typeof(secondarg)) = mbind(f_k.f_β, α, f_k.f_c)


"""
Expand All @@ -175,8 +178,8 @@ Represents a monatic bind resp. a mbind in general.
User code should not create instances of `Bind` directly, but should call
[`mbind`](@ref) instead.
"""
struct Bind{FK,M<:AbstractMeasure,FC} <: AbstractMeasure
f_β::FK
struct Bind{FT,M<:AbstractMeasure,FC} <: AbstractMeasure
f_β::FT
α::M
f_c::FC
end
Expand Down
5 changes: 2 additions & 3 deletions src/combinators/combined.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ export mcombine

@inline mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) = _generic_mcombine_impl_stage1(f_c, α, β)

@inline _generic_mcombine_impl_stage1(::typeof(first), α::AbstractMeasure, β::AbstractMeasure) = α
@inline _generic_mcombine_impl_stage1(::typeof(getsecond), α::AbstractMeasure, β::AbstractMeasure) = β
@inline _generic_mcombine_impl_stage1(::typeof(last), α::AbstractMeasure, β::AbstractMeasure) = β
@inline _generic_mcombine_impl_stage1(::typeof(firstarg), α::AbstractMeasure, β::AbstractMeasure) = α
@inline _generic_mcombine_impl_stage1(::typeof(secondarg), α::AbstractMeasure, β::AbstractMeasure) = β

@inline function _generic_mcombine_impl_stage1(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure)
productmeasure((α, β))
Expand Down

0 comments on commit 80913b1

Please sign in to comment.