Merge pull request #342 from calebweinreb/parallel_hmm_posterior_sample
Parallel hmm posterior sample
slinderman authored Jun 18, 2024
2 parents 46fb233 + 7978d13 commit a6b85ba
Showing 3 changed files with 119 additions and 6 deletions.
3 changes: 2 additions & 1 deletion dynamax/hidden_markov_model/__init__.py
@@ -23,4 +23,5 @@
from dynamax.hidden_markov_model.inference import compute_transition_probs

from dynamax.hidden_markov_model.parallel_inference import hmm_filter as parallel_hmm_filter
from dynamax.hidden_markov_model.parallel_inference import hmm_smoother as parallel_hmm_smoother
from dynamax.hidden_markov_model.parallel_inference import hmm_posterior_sample as parallel_hmm_posterior_sample
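
With this export, the new parallel sampler is importable from the package alongside the existing `parallel_hmm_filter` and `parallel_hmm_smoother` aliases. A minimal sketch of a call (the toy model below is illustrative, not part of this diff):

import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model import parallel_hmm_posterior_sample

# Toy two-state HMM with uninformative observations.
rng = jr.PRNGKey(0)
initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.9, 0.1], [0.2, 0.8]])
log_likelihoods = jnp.zeros((100, 2))

# Returns the log normalizer and one posterior sample of the state sequence.
log_normalizer, states = parallel_hmm_posterior_sample(
    rng, initial_probs, transition_matrix, log_likelihoods)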
30 changes: 30 additions & 0 deletions dynamax/hidden_markov_model/inference_test.py
@@ -2,6 +2,7 @@
import itertools as it
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
import dynamax.hidden_markov_model.inference as core
import dynamax.hidden_markov_model.parallel_inference as parallel

@@ -285,3 +286,32 @@ def test_parallel_smoother(key=0, num_timesteps=100, num_states=3):
    posterior = core.hmm_smoother(initial_probs, transition_matrix, log_likelihoods)
    posterior2 = parallel.hmm_smoother(initial_probs, transition_matrix, log_likelihoods)
    assert jnp.allclose(posterior.smoothed_probs, posterior2.smoothed_probs, atol=1e-1)


def test_parallel_posterior_sample(
    key=0, num_timesteps=5, num_states=2, eps=1e-3,
    num_samples=1000000, num_iterations=5
):
    if isinstance(key, int):
        key = jr.PRNGKey(key)

    # With num_states == 2 there are 2**num_timesteps distinct state sequences.
    max_unique_size = 1 << num_timesteps

    def iterate_test(key_iter):
        keys_iter = jr.split(key_iter, num_samples)
        args = random_hmm_args(key_iter, num_timesteps, num_states)

        # Sample sequences from posterior
        state_seqs = vmap(parallel.hmm_posterior_sample, (0, None, None, None), (0, 0))(keys_iter, *args)[1]
        unique_seqs, counts = jnp.unique(state_seqs, axis=0, size=max_unique_size, return_counts=True)
        blj_sample = counts / counts.sum()

        # Compute joint probabilities
        blj = jnp.exp(big_log_joint(*args))
        blj = jnp.ravel(blj / blj.sum())

        # Compare the joint distributions
        return jnp.allclose(blj_sample, blj, rtol=0, atol=eps)

    keys = jr.split(key, num_iterations)
    assert iterate_test(keys[0])
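
The helper `big_log_joint` is defined earlier in this test file, outside the hunk shown above. A hypothetical re-implementation for illustration, assuming it returns a (K, ..., K)-shaped array of $\log p(z_{1:T}, y_{1:T})$ over all $K^T$ state sequences:

import itertools as it
import jax.numpy as jnp

def big_log_joint_sketch(initial_probs, transition_matrix, log_likelihoods):
    # Enumerate all K**T state sequences. itertools.product varies the last
    # position fastest, which matches the C-order flattening of jnp.ravel
    # (and, for 0/1 states, the sorted order produced by jnp.unique above).
    T, K = log_likelihoods.shape
    log_joints = []
    for zs in it.product(range(K), repeat=T):
        lp = jnp.log(initial_probs[zs[0]]) + log_likelihoods[0, zs[0]]
        for t in range(1, T):
            lp += jnp.log(transition_matrix[zs[t - 1], zs[t]])
            lp += log_likelihoods[t, zs[t]]
        log_joints.append(lp)
    return jnp.array(log_joints).reshape((K,) * T)

The test then checks, elementwise and to within `eps`, that the empirical frequencies of one million sampled sequences match this exact joint.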
92 changes: 87 additions & 5 deletions dynamax/hidden_markov_model/parallel_inference.py
@@ -1,11 +1,23 @@
import jax.numpy as jnp
import jax.random as jr
from jax import lax, vmap, value_and_grad
from jaxtyping import Array, Float
from jaxtyping import Array, Float, Int
from typing import NamedTuple, Tuple, Union
from functools import partial

from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered

class Message(NamedTuple):
#---------------------------------------------------------------------------#
#                                 Filtering                                  #
#---------------------------------------------------------------------------#

class FilterMessage(NamedTuple):
    """Filtering associative scan elements.

    Attributes:
        A: $p(z_j \mid z_i)$
        log_b: $\log P(y_{i+1}, ..., y_j \mid z_i)$
    """
    A: Float[Array, "num_timesteps num_states num_states"]
    log_b: Float[Array, "num_timesteps num_states"]
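
A sketch of the composition rule implemented by `marginalize` in the hunk below, assuming `_condition_on(A, log_b)` (defined earlier in this file, outside the hunks shown) multiplies each row of $A$ by $e^{\log b}$ and normalizes, returning the normalized matrix and the per-row log-normalizer:

$$A_{ik} = \operatorname{rownorm}\big(A_{ij} \odot e^{\log b_{jk}}\big)\, A_{jk}, \qquad \log b_{ik} = \log b_{ij} + \operatorname{lognorm}\big(A_{ij} \odot e^{\log b_{jk}}\big)$$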

@@ -43,15 +55,15 @@ def marginalize(m_ij, m_jk):
        A_ij_cond, lognorm = _condition_on(m_ij.A, m_jk.log_b)
        A_ik = A_ij_cond @ m_jk.A
        log_b_ik = m_ij.log_b + lognorm
        return Message(A=A_ik, log_b=log_b_ik)
        return FilterMessage(A=A_ik, log_b=log_b_ik)


    # Initialize the messages
    A0, log_b0 = _condition_on(initial_probs, log_likelihoods[0])
    A0 *= jnp.ones((K, K))
    log_b0 *= jnp.ones(K)
    A1T, log_b1T = vmap(_condition_on, in_axes=(None, 0))(transition_matrix, log_likelihoods[1:])
    initial_messages = Message(
    initial_messages = FilterMessage(
        A=jnp.concatenate([A0[None, :, :], A1T]),
        log_b=jnp.vstack([log_b0, log_b1T])
    )
@@ -72,6 +84,11 @@ def marginalize(m_ij, m_jk):
        predicted_probs=predicted_probs)


#---------------------------------------------------------------------------#
#                                 Smoothing                                  #
#---------------------------------------------------------------------------#


def hmm_smoother(initial_probs: Float[Array, "num_states"],
                 transition_matrix: Float[Array, "num_states num_states"],
                 log_likelihoods: Float[Array, "num_timesteps num_states"]
@@ -109,4 +126,69 @@ def log_normalizer(log_initial_probs, log_transition_matrix, log_likelihoods):
        initial_probs=smoothed_probs[0],
        smoothed_probs=smoothed_probs,
        trans_probs=trans_probs
    )


#---------------------------------------------------------------------------#
#                                  Sampling                                  #
#---------------------------------------------------------------------------#
"""Associative scan elements $E_ij$ are vectors specifying a sample::
$z_j ~ p(z_j \mid z_i)$
for each possible value of $z_i$.
"""

def _initialize_sampling_messages(rng, transition_matrix, filtered_probs):
    """Preprocess filtering output to construct input for sampling associative scan."""
    T, K = filtered_probs.shape
    rngs = jr.split(rng, T)

    def _last_message(rng, probs):
        # Sample the final state $z_T$ from its filtered marginal and repeat
        # it K times so the message has the same shape as the generic ones.
        state = jr.choice(rng, K, p=probs)
        return jnp.repeat(state, K)

    @vmap
    def _generic_message(rng, probs):
        # Backward-sampling probabilities: row $i$ is
        # $p(z_t \mid z_{t+1}=i, y_{1:t}) \propto p(z_t \mid y_{1:t}) p(z_{t+1}=i \mid z_t)$.
        smoothed_probs = probs * transition_matrix.T
        smoothed_probs = smoothed_probs / smoothed_probs.sum(1).reshape(K, 1)
        # Draw one sample of $z_t$ for each possible value of $z_{t+1}$.
        return vmap(lambda p: jr.choice(rng, K, p=p))(smoothed_probs)

    En = _last_message(rngs[-1], filtered_probs[-1])
    Et = _generic_message(rngs[:-1], filtered_probs[:-1])
    return jnp.concatenate([Et, En[None]])


def hmm_posterior_sample(rng: jr.PRNGKey,
                         initial_distribution: Float[Array, "num_states"],
                         transition_matrix: Float[Array, "num_states num_states"],
                         log_likelihoods: Float[Array, "num_timesteps num_states"]
) -> Tuple[Float[Array, ""], Int[Array, "num_timesteps"]]:
    r"""Sample a sequence of hidden states from the posterior.

    Args:
        rng: random number generator
        initial_distribution: $p(z_1 \mid u_1, \theta)$
        transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
        log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.

    Returns:
        log_normalizer: $\log P(y_{1:T} \mid u_{1:T}, \theta)$
        states: sequence of hidden states $z_{1:T}$
    """
    T, K = log_likelihoods.shape

    # Run the HMM filter
    post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods)
    log_normalizer = post.marginal_loglik
    filtered_probs = post.filtered_probs

    # Each message is a lookup table from a conditioning state to a sampled
    # state, so composing two messages is chained indexing: jnp.take(E_ij, E_jk)
    # looks up E_ij at the states sampled by E_jk. Chained indexing is
    # associative, so lax.associative_scan can combine all messages in parallel.
    @vmap
    def _operator(E_jk, E_ij):
        return jnp.take(E_ij, E_jk)

    initial_messages = _initialize_sampling_messages(rng, transition_matrix, filtered_probs)
    final_messages = lax.associative_scan(_operator, initial_messages, reverse=True)
    states = final_messages[:, 0]
    return log_normalizer, states
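
End to end, a sketch of drawing many posterior samples in parallel with `vmap`, mirroring how the new test exercises this function (toy inputs, for illustration only):

import jax.numpy as jnp
import jax.random as jr
from jax import vmap
from dynamax.hidden_markov_model.parallel_inference import hmm_posterior_sample

rngs = jr.split(jr.PRNGKey(0), 1000)
initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.9, 0.1], [0.2, 0.8]])
log_likelihoods = jnp.zeros((50, 2))

# vmap only over the rng axis; the model arguments are shared across samples.
_, state_seqs = vmap(hmm_posterior_sample, (0, None, None, None), (0, 0))(
    rngs, initial_probs, transition_matrix, log_likelihoods)
# state_seqs has shape (1000, 50): one sampled state sequence per rng.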
