Skip to content

Commit

Permalink
update notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
SaremS committed Sep 30, 2024
1 parent 2f6c920 commit 3c3b88e
Show file tree
Hide file tree
Showing 23 changed files with 6,026 additions and 5 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: Trigger container build in blog repo

on:
push:
branches:
- master

jobs:
trigger-workflow:
runs-on: ubuntu-latest
steps:
- name: Trigger workflow in another repository
run: |
curl -X POST \
-H "Authorization: token ${{ secrets.WORKFLOW_PAT }}" \
-H "Accept: application/vnd.github.v3+json" \
https://api.github.com/repos/sarems/blog/actions/workflows/main.yml/dispatches \
-d '{"ref":"master"}'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
181 changes: 181 additions & 0 deletions .ipynb_checkpoints/Censored Data-checkpoint.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

292 changes: 292 additions & 0 deletions .ipynb_checkpoints/Julia HMM Scaled-checkpoint.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d57b391d-84b4-4447-9406-3ec503d83454",
"metadata": {},
"source": [
"---\n",
"title: \"Scaled forward-backward algorithm + AutoDiff for optimizing the observation distributions\"\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c77c4abe-4ef2-4278-a83a-c156b3386a57",
"metadata": {},
"outputs": [],
"source": [
"using Flux, Distributions, Zygote, ForwardDiff, Plots, StatsPlots, LinearAlgebra, Random"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e0bc4d4b-329d-43ea-8093-fbafc3a3473b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"EM (generic function with 2 methods)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mutable struct HMM\n",
" μ\n",
" σ\n",
" \n",
" P\n",
"end\n",
"Flux.@functor HMM (μ, σ)\n",
"\n",
"function HMM(states)\n",
" μ = collect(range(-1,1, length=states))\n",
" σ = zeros(states)\n",
" \n",
" P = softmax(Matrix(Diagonal(ones(states))),dims=2)\n",
" \n",
" return HMM(μ,σ,P)\n",
"end\n",
"\n",
"\n",
"function filter(m::HMM, y, p_t)\n",
" \n",
" μ = m.μ\n",
" σ = exp.(m.σ)\n",
" P = m.P\n",
" \n",
" y_t = y[1]\n",
" \n",
" dists_t = Normal.(μ,σ)\n",
" pdfs = pdf.(dists_t, y_t)\n",
" \n",
" sumdist = (p_t.*pdfs)\n",
" p_tt = sumdist./sum(sumdist)\n",
" \n",
" p_tp1 = P*p_tt\n",
" \n",
" if length(y)>1\n",
" dists_tp1, p_t, p_ttp1 = filter(m,y[2:end],p_tp1)\n",
" return vcat(dists_t, dists_tp1), hcat(p_t,p_tp1), hcat(p_tt, p_ttp1)\n",
" else\n",
" return dists_t, p_t, p_tt\n",
" end\n",
"end\n",
"\n",
"function forward_normalized(m::HMM, y, α_tm1)\n",
" #https://github.com/mattjj/pyhsmm/blob/e6cfde5acb98401c2e727ca59a49ee0bfe86cf9d/pyhsmm/internals/hmm_states.py#L322\n",
" \n",
" μ = m.μ\n",
" σ = exp.(m.σ)\n",
" P = m.P\n",
" \n",
" y_t = y[1]\n",
" qsum = P*α_tm1\n",
" dists = Normal.(μ,σ)\n",
" lpdfs = logpdf.(dists,y_t)\n",
" \n",
" lpdf_max = maximum(lpdfs)\n",
" \n",
" α_t = qsum[:] .* exp.(lpdfs .- lpdf_max)\n",
" normalizer = sum(α_t)\n",
" \n",
" α_t_normed = α_t ./ normalizer\n",
" logtot_t = log(normalizer) + lpdf_max\n",
" \n",
" if length(y)>1\n",
" α_tp1_normed, logtot_tp1 = forward_normalized(m,y[2:end],α_t_normed)\n",
" return hcat(α_t_normed, α_tp1_normed), logtot_t + logtot_tp1\n",
" else\n",
" return α_t_normed, logtot_t\n",
" end\n",
"end\n",
"\n",
"\n",
"function backward_normalized(m::HMM, y, β_tp1)\n",
" #https://github.com/mattjj/pyhsmm/blob/e6cfde5acb98401c2e727ca59a49ee0bfe86cf9d/pyhsmm/internals/hmm_states.py#L295\n",
" \n",
" μ = m.μ\n",
" σ = exp.(m.σ)\n",
" P = m.P\n",
" \n",
" y_t = y[end]\n",
" \n",
" dists = Normal.(μ,σ)\n",
" lpdfs = logpdf.(dists,y_t)\n",
" \n",
" lpdf_max = maximum(lpdfs)\n",
" \n",
" β_t = transpose(P)*(β_tp1.*exp.(lpdfs.-lpdf_max))[:]\n",
" normalizer = sum(β_t)\n",
" \n",
" β_t_normed = β_t./normalizer\n",
" logtot_t = log(normalizer) + lpdf_max\n",
" \n",
" if length(y)>1\n",
" β_tm1_normed, logtot_tm1 = backward_normalized(m, y[1:end-1], β_t_normed)\n",
" return hcat(β_tm1_normed, β_t_normed), logtot_tm1 + logtot_t\n",
" else\n",
" return β_t_normed, logtot_t\n",
" end\n",
"end\n",
"\n",
"\n",
"\n",
"function likelihood(m::HMM,y,sps)\n",
" μ = m.μ\n",
" σ = exp.(m.σ)\n",
" \n",
" dists = Normal.(μ,σ)\n",
" \n",
" return mean(map(i->sum(sps[i].*logpdf.(dists,y[i])),1:length(y)))\n",
"end\n",
"\n",
"\n",
"function EM(m::HMM, y, p_0, n_iter = 50)\n",
" \n",
" for i in 1:n_iter\n",
" α, logtot_α = forward_normalized(m,y,p_0)\n",
" β, logtot_β = backward_normalized(m,y,p_0)\n",
" \n",
" αβ = α.*β \n",
" γ = αβ./sum(αβ,dims=1)\n",
" \n",
" \n",
" sps = Flux.unstack(γ,dims=2)\n",
"\n",
" ps, f = Flux.destructure(m)\n",
" \n",
" for _ in 1:50\n",
" grads = ForwardDiff.gradient(x -> -likelihood(f(x),y,sps), ps)\n",
" ps .-= 0.001.*grads\n",
" end\n",
" \n",
" newm = f(ps)\n",
" m.μ = newm.μ\n",
" m.σ = newm.σ\n",
" \n",
" Ps = Matrix(transpose(hcat([sum(γ[i:i, 1:end-1].*γ[:, 2:end],dims=2)[:] for i in 1:length(p_0)]...)))\n",
" Ps./=sum(Ps,dims=1)\n",
" \n",
" m.P = Ps\n",
" \n",
" if i%50==0\n",
" println(-likelihood(m,y,sps))\n",
" end\n",
" end\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "548eb4c7-c06d-4923-8a43-b88c1e09d90e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([-1.0, 0.0, 1.0, 0.0, 0.0, 0.0], Restructure(HMM, ..., 6))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Random.seed!(321)\n",
"y = vcat([vcat(0.5 .*randn(25).+3, randn(25), 0.5 .*randn(25).-3) for _ in 1:10]...)\n",
"\n",
"m = HMM(3)\n",
"ps, f = Flux.destructure(m)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a076f0ea-6894-474f-b07e-0f167e19883a",
"metadata": {},
"outputs": [],
"source": [
"α, logtot_α = forward_normalized(m,y,ones(3)./3)\n",
"β, logtot_β = backward_normalized(m,y,ones(3)./3)\n",
"\n",
"αβ = α.*β \n",
"state_probs = αβ./sum(αβ,dims=1)\n",
"\n",
"\n",
"mean_pred = sum(m.μ .* state_probs,dims=1)[:]\n",
"std_pred = sqrt.(sum(exp.(m.σ) .* state_probs,dims=1)[:])\n",
"\n",
"p1 = scatter(collect(1:length(y)), y, title = \"Smoothing distribution before training\", label = \"Data\",fmt=:png,\n",
" c=\"blue\", legend=:bottomleft, size = (1200,600))\n",
"plot!(p1, mean_pred, label = \"Predicted mean + 2 stddevs\", c=\"red\", ribbon = 2 .* std_pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9b914d3-aec8-4c73-a279-6927aa07c36d",
"metadata": {},
"outputs": [],
"source": [
"EM(m,y,ones(3)./3,750)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d59f2b74-a719-4542-bd5d-39153114a5b5",
"metadata": {},
"outputs": [],
"source": [
"α, logtot_α = forward_normalized(m,y,ones(3)./3)\n",
"β, logtot_β = backward_normalized(m,y,ones(3)./3)\n",
"\n",
"αβ = α.*β \n",
"state_probs = αβ./sum(αβ,dims=1)\n",
"\n",
"\n",
"mean_pred = sum(m.μ .* state_probs,dims=1)[:]\n",
"std_pred = sqrt.(sum(exp.(m.σ) .* state_probs,dims=1)[:])\n",
"\n",
"p2 = scatter(collect(1:length(y)), y, title = \"Smoothing distribution after training\", label = \"Data\",fmt=:png,\n",
" c=\"blue\", legend=:bottomleft, size = (1200,600))\n",
"plot!(p2, mean_pred, label = \"Predicted mean + 2 stddevs\", c=\"red\", ribbon = 2 .* std_pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a295f6a-71f3-4450-b1a2-755c56f0f83c",
"metadata": {},
"outputs": [],
"source": [
"plot(p1,p2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.8.1",
"language": "julia",
"name": "julia-1.8"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.8.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit 3c3b88e

Please sign in to comment.