
Support for HMMs with num_states=1 #380

Open
umeshksingla opened this issue Oct 1, 2024 · 1 comment

Comments

@umeshksingla

I am trying to fit various HMM classes (LinearRegressionHMM or GaussianHMM) to my data, but they do not let me pass num_states=1. For num_states >= 2, everything works as expected. I wanted to know whether the lack of support for num_states=1 is intended behavior.

It's easy enough to write code for simple linear regression outside dynamax; however, that makes the comparison with the num_states >= 2 cases error-prone (e.g. one might end up using different constants in the log-likelihood calculations).

If it helps, the error occurs while trying to initialize the Dirichlet distribution.

File "/Users/us/project/fitting.py", line 20, in fitEM
    params, props = hmm.initialize(key)
  File "/Users/us/dynaenv/lib/python3.10/site-packages/dynamax/hidden_markov_model/models/gaussian_hmm.py", line 649, in initialize
    params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
  File "/Users/us/dynaenv/lib/python3.10/site-packages/dynamax/hidden_markov_model/models/initial.py", line 45, in initialize
    initial_probs = tfd.Dirichlet(self.initial_probs_concentration).sample(seed=this_key)
...
ValueError: Argument `concentration` must have `event_size` at least 2.
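For context on the comparison point above: with a single state the HMM's initial and transition terms contribute log(1) = 0, so the model collapses to fitting a single emission distribution by maximum likelihood. A minimal numpy sketch (independent of dynamax; the helper name is my own) of the equivalent log-likelihood, with the constants written out explicitly:

```python
import numpy as np

def single_gaussian_loglik(x):
    """Log-likelihood of 1-D data under the MLE Gaussian -- the num_states=1 case.

    With one hidden state, the initial and transition terms are log(1) = 0,
    so only the emission log-density remains.
    """
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    mu = x.mean()   # MLE mean
    var = x.var()   # MLE (biased) variance
    # sum_t log N(x_t | mu, var) evaluated at the MLE simplifies to:
    return -0.5 * n * (np.log(2 * np.pi * var) + 1.0)

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.5, size=100)
print(single_gaussian_loglik(x))
```

Keeping the constants explicit like this is exactly what makes an external baseline comparable with the dynamax log-likelihoods.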
@gileshd
Collaborator

gileshd commented Oct 3, 2024

Hi @umeshksingla! This is an interesting scenario.

We do make the assumption that there are at least 2 hidden states (I suppose a model with only one hidden state is not really a HMM).

In practice, some behaviours seem to work fine with one hidden state. However, as you have found, we run into an error whenever we interact with tfd.Dirichlet distributions, since they require a concentration with at least two elements. To resolve this, we would need to detect the single-state scenario and treat it as a special case.
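One possible shape for such a special case, sketched as a hypothetical helper in plain numpy (not dynamax/tfp code; the function name and default concentration are my own):

```python
import numpy as np

def init_initial_probs(num_states, concentration=1.1, rng=None):
    """Hypothetical special-cased initializer for the initial state distribution.

    With a single state the distribution is degenerate, so return [1.0]
    directly instead of sampling a Dirichlet (which needs event_size >= 2).
    """
    if num_states == 1:
        return np.array([1.0])
    rng = rng or np.random.default_rng()
    return rng.dirichlet(np.full(num_states, concentration))

print(init_initial_probs(1))  # [1.]
probs = init_initial_probs(3, rng=np.random.default_rng(0))
print(probs)                  # a random point on the 3-simplex, sums to 1
```

The same branch would be needed everywhere a Dirichlet is touched (initialization, priors, M-step), which is the complexity concern below.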

I am not totally sure we want to add complexity to handle this somewhat niche, potentially out-of-scope use case. Perhaps it is worth doing; if not, we should at the very least document that num_states must be >= 2 and improve the error messages here.

For your present purposes, you can avoid the call to tfd.Dirichlet during initialization by manually specifying the initial probabilities and transition matrix. This is straightforward, as those parameters can only take a single value when there is one state. For instance:

from jax import numpy as jnp
from jax import random as jr
from dynamax.hidden_markov_model import GaussianHMM

hmm = GaussianHMM(num_states=1, emission_dim=1)

initial_probs = jnp.array([1.0])
transition_matrix = jnp.array([[1.0]])
params, props = hmm.initialize(initial_probs=initial_probs, transition_matrix=transition_matrix)

z, x = hmm.sample(params, key=jr.PRNGKey(0), num_timesteps=10)
# z is Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32) 

You should be able to use this model for sampling as normal.

However, parameter learning (e.g. .fit_em(), .fit_sgd()) will not work, as the Dirichlet distribution is called during the fitting process:

hmm = GaussianHMM(num_states=1, emission_dim=1)
initial_probs = jnp.array([1.0])
transition_matrix = jnp.array([[1.0]])
params, props = hmm.initialize(initial_probs=initial_probs, transition_matrix=transition_matrix)

z, x = hmm.sample(params, key=jr.PRNGKey(0), num_timesteps=100)

params_inf, props = hmm.initialize(key=jr.PRNGKey(10), initial_probs=initial_probs, transition_matrix=transition_matrix)

props.initial.probs.trainable = False
props.transitions.transition_matrix.trainable = False
try:
    hmm.fit_em(params_inf, props, emissions=x)
except ValueError as e:
    print(f"Error: {e}")

One work-around for this is to make a model with num_states=2 but specify the initial state distribution and transition matrix so that the model will behave as if it has only one state.

Here is an example:

from jax import numpy as jnp
from jax import random as jr
from dynamax.hidden_markov_model import GaussianHMM

hmm = GaussianHMM(num_states=2, emission_dim=1)

initial_probs = jnp.array([1.0, 0.0])
transition_matrix = jnp.array([[1.0, 0.0], [1.0, 0.0]])
params, props = hmm.initialize(key=jr.PRNGKey(0), initial_probs=initial_probs, transition_matrix=transition_matrix)

# Re-sample emissions from the "true" parameters so this snippet is
# self-contained (the previous snippet's x would work equally well)
z, x = hmm.sample(params, key=jr.PRNGKey(0), num_timesteps=100)

params_inf, props = hmm.initialize(key=jr.PRNGKey(100), initial_probs=initial_probs, transition_matrix=transition_matrix)

props.initial.probs.trainable = False
props.transitions.transition_matrix.trainable = False

hmm.fit_em(params_inf, props, emissions=x)

This might give okay parameter estimates; however, the log-probability calculations do not handle this setup gracefully, and you may get jnp.inf or jnp.nan.
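The inf/nan behaviour is easy to see in isolation: the zeros in the degenerate transition matrix hit log(0) as soon as the model works in log space. A small numpy illustration (not dynamax code):

```python
import numpy as np

# The degenerate two-state transition matrix from the example above
transition_matrix = np.array([[1.0, 0.0],
                              [1.0, 0.0]])

with np.errstate(divide="ignore"):
    log_T = np.log(transition_matrix)  # log(0) -> -inf

print(log_T)
# In exact arithmetic the zero-probability state never contributes, but
# numerically 0.0 * -inf evaluates to nan, which can then propagate
# through the log-probability computation:
print(0.0 * log_T[0, 1])  # nan
```

So even when the posteriors place no mass on the dummy state, intermediate products of the form 0 * (-inf) can contaminate the result.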
