Skip to content

Commit

Permalink
added test for parallel hmm sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
Caleb Weinreb authored and Caleb Weinreb committed Sep 5, 2023
1 parent 45af584 commit 7978d13
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions dynamax/hidden_markov_model/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools as it
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
import dynamax.hidden_markov_model.inference as core
import dynamax.hidden_markov_model.parallel_inference as parallel

Expand Down Expand Up @@ -285,3 +286,32 @@ def test_parallel_smoother(key=0, num_timesteps=100, num_states=3):
posterior = core.hmm_smoother(initial_probs, transition_matrix, log_likelihoods)
posterior2 = parallel.hmm_smoother(initial_probs, transition_matrix, log_likelihoods)
assert jnp.allclose(posterior.smoothed_probs, posterior2.smoothed_probs, atol=1e-1)


def test_parallel_posterior_sample(
key=0, num_timesteps=5, num_states=2, eps=1e-3,
num_samples=1000000, num_iterations=5
):
if isinstance(key, int):
key = jr.PRNGKey(key)

max_unique_size = 1 << num_timesteps

def iterate_test(key_iter):
keys_iter = jr.split(key_iter, num_samples)
args = random_hmm_args(key_iter, num_timesteps, num_states)

# Sample sequences from posterior
state_seqs = vmap(parallel.hmm_posterior_sample, (0, None, None, None), (0, 0))(keys_iter, *args)[1]
unique_seqs, counts = jnp.unique(state_seqs, axis=0, size=max_unique_size, return_counts=True)
blj_sample = counts / counts.sum()

# Compute joint probabilities
blj = jnp.exp(big_log_joint(*args))
blj = jnp.ravel(blj / blj.sum())

# Compare the joint distributions
return jnp.allclose(blj_sample, blj, rtol=0, atol=eps)

keys = jr.split(key, num_iterations)
assert iterate_test(keys[0])

0 comments on commit 7978d13

Please sign in to comment.