Fit CategoricalHMM with available data? #345

Open
AndyWeasley2004 opened this issue Oct 12, 2023 · 1 comment
@AndyWeasley2004

I have worked through the example in https://probml.github.io/dynamax/notebooks/hmm/casino_hmm_learning.html

However, in that example the params and props are all generated by the initialize function. If I already have ready-to-use lists, one input list (X in general ML terms) and one label list (y in general ML terms), how can I use the fit function?

I know this question may be naive, but I'm relatively new to Python. I would greatly appreciate any help.

@gileshd
Collaborator

gileshd commented Nov 1, 2023

Hi @AndyWeasley2004,

I'm not sure I have fully understood your question, but perhaps the following will help.

At the start of the demo there is an example of choosing the parameter values by hand:

```python
import jax.numpy as jnp
from dynamax.hidden_markov_model import CategoricalHMM

num_states = 2      # two types of dice (fair and loaded)
num_emissions = 1   # only one die is rolled at a time
num_classes = 6     # each die has six faces

initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.95, 0.05],
                               [0.10, 0.90]])
emission_probs = jnp.array([[1/6,  1/6,  1/6,  1/6,  1/6,  1/6],    # fair die
                            [1/10, 1/10, 1/10, 1/10, 1/10, 5/10]])  # loaded die

# Construct the HMM
hmm = CategoricalHMM(num_states, num_emissions, num_classes)

# Initialize the parameters struct with known values.
# initialize returns (params, props); props is not needed here, so it is discarded.
params, _ = hmm.initialize(initial_probs=initial_probs,
                           transition_matrix=transition_matrix,
                           emission_probs=emission_probs.reshape(num_states, num_emissions, num_classes))
```

In this example the parameter values are taken directly from the arrays initial_probs, transition_matrix, and emission_probs.

Here, hmm.initialize takes the arrays we defined and converts them into the appropriate parameter object (an instance of ParamsCategoricalHMM).

If you want to use your own parameter values, you can convert your lists into JAX arrays (jnp.array(param_list)) and pass them to the initialize method as above.

The initialize method can also sample random parameter values if you pass it a random key. The demo does this later on, for example:
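A minimal sketch, assuming your parameter values start out as nested Python lists: the hypothetical helper to_param_arrays below (not part of dynamax) converts them with NumPy, reshapes the emission probabilities to (num_states, num_emissions, num_classes) as the demo does, and checks that every distribution sums to one before you hand the arrays to hmm.initialize.

```python
import numpy as np

def to_param_arrays(initial_list, transition_list, emission_list, num_emissions=1):
    """Hypothetical helper: convert nested lists into arrays shaped like the demo's inputs.

    Returns (initial_probs, transition_matrix, emission_probs), with emission_probs
    reshaped to (num_states, num_emissions, num_classes).
    """
    initial_probs = np.asarray(initial_list, dtype=float)
    transition_matrix = np.asarray(transition_list, dtype=float)
    emission_probs = np.asarray(emission_list, dtype=float)

    num_states = initial_probs.shape[0]
    num_classes = emission_probs.shape[-1]

    # One initial probability per state and a square transition matrix.
    if transition_matrix.shape != (num_states, num_states):
        raise ValueError(f"transition_matrix must be {num_states}x{num_states}")
    # Same reshape as in the demo; raises ValueError if the sizes do not match.
    emission_probs = emission_probs.reshape(num_states, num_emissions, num_classes)

    # Every probability vector (the last axis) must sum to one.
    for name, arr in [("initial_probs", initial_probs),
                      ("transition_matrix", transition_matrix),
                      ("emission_probs", emission_probs)]:
        if not np.allclose(arr.sum(axis=-1), 1.0):
            raise ValueError(f"{name} rows must sum to 1")
    return initial_probs, transition_matrix, emission_probs


init, trans, emit = to_param_arrays(
    [0.5, 0.5],
    [[0.95, 0.05], [0.10, 0.90]],
    [[1/6] * 6, [1/10] * 5 + [5/10]],
)
print(emit.shape)  # (2, 1, 6)
```

The returned arrays can then be wrapped with jnp.array(...) and passed to hmm.initialize in the same way as the dice example at the start of the demo.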

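```python
import jax.random as jr

# hmm and batch_emissions are defined earlier in the demo;
# batch_emissions has shape (num_batches, num_timesteps, num_emissions).
key = jr.PRNGKey(0)
em_params, em_param_props = hmm.initialize(key)
em_params, log_probs = hmm.fit_em(em_params,
                                  em_param_props,
                                  batch_emissions,
                                  num_iters=400)
```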
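For intuition about what a random starting point for EM looks like, here is a NumPy stand-in that draws each probability vector from a flat Dirichlet distribution. This is an assumption chosen for illustration: dynamax draws its own random values inside hmm.initialize(key), and its distributions may differ. The helper name random_categorical_hmm_params is hypothetical. Because EM only converges to a local optimum, trying a few different seeds is a common precaution.

```python
import numpy as np

def random_categorical_hmm_params(seed, num_states, num_emissions, num_classes):
    """Hypothetical helper: draw a random starting point for EM (illustrative only).

    Each probability vector is sampled from a flat Dirichlet(1, ..., 1),
    so every entry is non-negative and every row sums to one.
    """
    rng = np.random.default_rng(seed)
    initial_probs = rng.dirichlet(np.ones(num_states))
    # One Dirichlet draw per row of the transition matrix.
    transition_matrix = rng.dirichlet(np.ones(num_states), size=num_states)
    # One distribution over classes per (state, emission) pair.
    emission_probs = rng.dirichlet(np.ones(num_classes),
                                   size=(num_states, num_emissions))
    return initial_probs, transition_matrix, emission_probs


init, trans, emit = random_categorical_hmm_params(0, num_states=2, num_emissions=1, num_classes=6)
print(init.shape, trans.shape, emit.shape)  # (2,) (2, 2) (2, 1, 6)
```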

It sounds like your "label list" might correspond to the HMM emissions, in which case you can pass it to the fit function as the emissions argument (in the example above, batch_emissions is passed as that argument).

I am not entirely sure what you mean by your "input lists", but if you can provide more details I am happy to help further.
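If your labels are integer class indices, a sketch of turning them into an emissions array shaped like batch_emissions might look like the following. The helper labels_to_emissions is hypothetical; it assumes equal-length sequences, values in 0..num_classes-1 (remap 1-based dice faces or string labels first), and num_emissions = 1. The fit_em call is left commented out because it needs the hmm, em_params, and em_param_props objects from the demo.

```python
import numpy as np

def labels_to_emissions(label_seqs, num_classes):
    """Hypothetical helper: stack equal-length integer label sequences.

    Returns an int32 array of shape (num_sequences, num_timesteps, 1),
    i.e. one categorical emission per time step, as with num_emissions = 1.
    """
    emissions = np.asarray(label_seqs, dtype=np.int32)
    if emissions.ndim == 1:
        # A single sequence: treat it as a batch of one.
        emissions = emissions[None, :]
    if emissions.ndim != 2:
        raise ValueError("expected a list of equal-length label sequences")
    if emissions.min() < 0 or emissions.max() >= num_classes:
        raise ValueError(f"labels must lie in [0, {num_classes})")
    # Add the trailing emission dimension.
    return emissions[..., None]


y = [[0, 5, 5, 2], [1, 1, 3, 5]]
emissions = labels_to_emissions(y, num_classes=6)
print(emissions.shape)  # (2, 4, 1)

# With dynamax and the demo's objects available (not run here):
# em_params, log_probs = hmm.fit_em(em_params, em_param_props,
#                                   jnp.array(emissions), num_iters=400)
```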
