Skip to content

Commit

Permalink
Change GNN to NGMM and VMM to NVMMM
Browse files Browse the repository at this point in the history
  • Loading branch information
billera committed Dec 14, 2023
1 parent 7512913 commit 98b64be
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
26 changes: 13 additions & 13 deletions src/GaussianMixtureLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ using LinearAlgebra
using ConditionalDensityLayers: AbstractConditionalDensityLayer
import StatsBase

export GNNLayer
export GNNSettings
export NGMMLayer
export NGMMSettings

struct GNNLayer
struct NGMMLayer
central_network::Chain
weight_network::Chain
centroid_network::Chain
std_network::Chain
K::Int
N_dims::Int
end
Flux.@functor GNNLayer
Flux.@functor NGMMLayer

function GNNSettings(sizeof_conditionvector; K = 20, N_dims = 3, numembeddings = 256, numhiddenlayers = 6, σ = relu, p = 0.05f0)
function NGMMSettings(sizeof_conditionvector; K = 20, N_dims = 3, numembeddings = 256, numhiddenlayers = 6, σ = relu, p = 0.05f0)
return (
K = K,
N_dims = N_dims,
Expand All @@ -32,9 +32,9 @@ end



function GNNLayer(settings::NamedTuple)
function NGMMLayer(settings::NamedTuple)
s = settings
return GNNLayer(
return NGMMLayer(
K = s.K,
N_dims = s.N_dims,
sizeof_conditionvector = s.sizeof_conditionvector,
Expand All @@ -45,11 +45,11 @@ function GNNLayer(settings::NamedTuple)
)
end

function GNNLayer(settings_vector::Vector{<: NamedTuple})
return [GNNLayer(settings) for settings in settings_vector]
function NGMMLayer(settings_vector::Vector{<: NamedTuple})
return [NGMMLayer(settings) for settings in settings_vector]
end

function GNNLayer(; K::Integer, N_dims::Integer, sizeof_conditionvector::Integer, numembeddings::Integer = 256, numhiddenlayers::Integer = 20, σ = relu, p = 0.05f0)
function NGMMLayer(; K::Integer, N_dims::Integer, sizeof_conditionvector::Integer, numembeddings::Integer = 256, numhiddenlayers::Integer = 20, σ = relu, p = 0.05f0)
lays = []
for i in 2:3*numhiddenlayers
# alternate dense -> layernorm -> dense -> dropout -> ...
Expand All @@ -73,10 +73,10 @@ function GNNLayer(; K::Integer, N_dims::Integer, sizeof_conditionvector::Integer
std_network = Chain(
Dense(numembeddings => K, softplus)
)
GNNLayer(central_network, weight_network, centroid_network, std_network, K, N_dims)
NGMMLayer(central_network, weight_network, centroid_network, std_network, K, N_dims)
end

function (g::GNNLayer)(conditionvector)
function (g::NGMMLayer)(conditionvector)
batch = size(conditionvector)[2:end]
S_emb = g.central_network(conditionvector)
weights = Flux.softmax(reshape(g.weight_network(S_emb), g.K, batch...), dims = 1)
Expand All @@ -99,7 +99,7 @@ function NLLIsotropicGMM(x, w, μ, σ)
return -sum(log_likelihood)./(K .* batch)
end

function loss(g::GNNLayer, Y::AbstractVecOrMat{Float32}, conditionvector::AbstractVecOrMat{Float32})
function loss(g::NGMMLayer, Y::AbstractVecOrMat{Float32}, conditionvector::AbstractVecOrMat{Float32})
w, μ, σ = get_gmm_params(g, conditionvector)
return NLLIsotropicGMM(Y, w, μ, σ)
end
Expand Down
30 changes: 15 additions & 15 deletions src/VonMisesMixtureLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@ using Distributions: log2π
using Distributions
import StatsBase

export VMMLayer
export NVMMMLayer
export VonMisesNucleusSample
export VMMSettings
export NVMMMSettings

struct VMMLayer
struct NVMMMLayer
central_network::Chain
weight_network::Chain
μx_network::Chain
μy_network::Chain
K::Int
N_dims::Int
end
Flux.@functor VMMLayer
Flux.@functor NVMMMLayer

function VMMSettings(sizeof_conditionvector; K = 20, N_dims = 3, numembeddings = 256, numhiddenlayers = 6, σ = relu, p = 0.05f0)
function NVMMMSettings(sizeof_conditionvector; K = 20, N_dims = 3, numembeddings = 256, numhiddenlayers = 6, σ = relu, p = 0.05f0)
return (
K = K,
N_dims = N_dims,
Expand All @@ -33,9 +33,9 @@ function VMMSettings(sizeof_conditionvector; K = 20, N_dims = 3, numembeddings =
)
end

function VMMLayer(settings::NamedTuple)
function NVMMMLayer(settings::NamedTuple)
s = settings
return VMMLayer(
return NVMMMLayer(
K = s.K,
N_dims = s.N_dims,
sizeof_conditionvector = s.sizeof_conditionvector,
Expand All @@ -46,11 +46,11 @@ function VMMLayer(settings::NamedTuple)
)
end

function VMMLayer(settings_vector::Vector{<: NamedTuple})
return [VMMLayer(settings) for settings in settings_vector]
function NVMMMLayer(settings_vector::Vector{<: NamedTuple})
return [NVMMMLayer(settings) for settings in settings_vector]
end

function VMMLayer(; K::Integer, N_dims, sizeof_conditionvector::Integer, numembeddings::Integer = 256, numhiddenlayers::Integer = 6, σ = relu, p = 0.05f0)
function NVMMMLayer(; K::Integer, N_dims, sizeof_conditionvector::Integer, numembeddings::Integer = 256, numhiddenlayers::Integer = 6, σ = relu, p = 0.05f0)
lays = []
for i in 2:3*numhiddenlayers
if i % 3 == 1
Expand All @@ -73,7 +73,7 @@ function VMMLayer(; K::Integer, N_dims, sizeof_conditionvector::Integer, numembe
weight_network = Chain(
Dense(floor(Int, numembeddings) => K)
)
VMMLayer(central_network, weight_network, μx_network, μy_network, K, N_dims)
NVMMMLayer(central_network, weight_network, μx_network, μy_network, K, N_dims)
end


Expand Down Expand Up @@ -123,7 +123,7 @@ Gets the Von-Mises parameters for the mixture from an embedding vector S.
κ = μx^2 + μy^2, and
μ = atan(μx, μy)
"""
function get_vmm_params(v::VMMLayer, S)
function get_vmm_params(v::NVMMMLayer, S)
batch = size(S)[2:end]
S_emb = v.central_network(S)
N_dims = v.N_dims
Expand All @@ -138,7 +138,7 @@ end
"""
Computes the mixture-NLL loss of the Von-Mises Mixture Layer, given an embedding vector S and true values θ.
"""
function loss(v::VMMLayer, θ, S)
function loss(v::NVMMMLayer, θ, S)
μ, κ, w = get_vmm_params(v, S)
return NLLVonMisesMixture(θ, μ, κ, w), (θ,μ)
end
Expand Down Expand Up @@ -169,7 +169,7 @@ end
"""
Given an embedding vector S, generates N_samples samples from the resulting mixtures.
"""
function VMMSample(v::VMMLayer, S::AbstractVecOrMat{T}; N_samples::Integer = 1000) where T <: Real
function NVMMMSample(v::NVMMMLayer, S::AbstractVecOrMat{T}; N_samples::Integer = 1000) where T <: Real
μ, κ, w = get_vmm_params(v, S)
μ, κ, w = Float64.(μ), Float64.(κ), Float64.(w)
samps = [stack([VonMisesSample(μ[:,:,i], κ[:,:,i], w[:,i]) for j in 1:N_samples]) for i in axes(S,2)]
Expand Down Expand Up @@ -203,7 +203,7 @@ function VonMisesNucleusSample(μ, κ, w; N_samps = 10000, Pd = 0.8, same_dist_s
return dihs_sampled
end

function (v::VMMLayer)(S)
function (v::NVMMMLayer)(S)
return get_vmm_params(v, S)
end

Expand Down

0 comments on commit 98b64be

Please sign in to comment.