My non-stationary Markov toy example isn't learning #363
I found the bug. This is the correct version of `distribution`:

```python
def distribution(self, params, state, inputs=None):
    cycle_index, state_index = state
    return tfd.JointDistributionSequential([
        tfd.Deterministic((cycle_index + 1) % self.cycle_dim),
        tfd.Categorical(probs=params.transition_matrix[cycle_index.astype(jnp.uint32), state_index])
    ])
```

It is important to not increment the `cycle_index` that is being used to look up into the transition matrix.
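To see why this matters, here is a minimal, self-contained sketch (plain `jax.numpy`, with a hypothetical bank `T` of per-cycle transition matrices and made-up numbers). The row of next-state probabilities has to be read at the *current* cycle position; only the cycle index stored in the next state is advanced.

```python
import jax.numpy as jnp

cycle_dim, num_states = 3, 2
# Hypothetical bank of per-cycle transition matrices (each row sums to 1).
T = jnp.array([[[0.9, 0.1], [0.2, 0.8]],
               [[0.5, 0.5], [0.5, 0.5]],
               [[0.1, 0.9], [0.8, 0.2]]])

cycle_index, state_index = 0, 1  # current cycle position and current hidden state

# Incrementing before the lookup reads the wrong matrix (cycle 1 instead of cycle 0):
wrong_probs = T[(cycle_index + 1) % cycle_dim, state_index]   # [0.5, 0.5]
# Correct: look up with the current cycle index ...
probs = T[cycle_index, state_index]                           # [0.2, 0.8]
# ... and advance only the cycle index that gets stored in the next state.
next_cycle_index = (cycle_index + 1) % cycle_dim              # 1
```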
Here is the full, corrected code:

```python
import numpy as np
from jaxtyping import Array, Float
from functools import partial
from typing import NamedTuple, Union, Tuple, Optional
import jax
import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax.nn import one_hot
import optax
from dynamax.parameters import ParameterProperties, ParameterSet, PropertySet
from dynamax.utils.utils import pytree_sum
from dynamax.types import Scalar
import matplotlib.pyplot as plt
from dynamax.hidden_markov_model.models.abstractions import HMM
# from dynamax.hidden_markov_model.models.abstractions import HMMInitialState, HMMEmissions, HMMTransitions
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.categorical_hmm import CategoricalHMMEmissions, StandardHMMTransitions
# from dynamax.hidden_markov_model.models.categorical_hmm import ParamsCategoricalHMM, ParamsCategoricalHMMEmissions, ParamsStandardHMMTransitions
from dynamax.hidden_markov_model.inference import hmm_two_filter_smoother
from dynamax.parameters import to_unconstrained, from_unconstrained
from dynamax.types import PRNGKey
from dynamax.utils.optimize import run_sgd
from dynamax.utils.utils import ensure_array_has_batch_dim
class CyclingHMMInitialState(StandardHMMInitialState):
    """Initial state distribution for the cycling HMM.

    The initial state is a (cycle position, hidden state) pair.
    """
    def __init__(self,
                 num_states,
                 initial_probs_concentration=1.1):
        """
        Args:
            initial_probabilities[k]: prob(hidden(1)=k)
        """
        self.num_states = num_states
        self.initial_probs_concentration = initial_probs_concentration * jnp.ones(num_states)

    def distribution(self, params, inputs=None):
        return tfd.JointDistributionSequential([
            tfd.Deterministic(0),  # Always start at the 0th transition matrix in the cycle
            tfd.Categorical(probs=params.probs),
        ])


class ParamsCyclingHMMTransitions(NamedTuple):
    transition_matrix: Union[Float[Array, "cycle_dim state_dim state_dim"], ParameterProperties]


class CyclingHMMTransitions(StandardHMMTransitions):
    r"""Cycling model for HMM transitions.

    We place a Dirichlet prior over the rows of the transition matrix $A$,

    $$A_k \sim \mathrm{Dir}(\beta 1_K + \kappa e_k)$$

    where

    * $1_K$ denotes a length-$K$ vector of ones,
    * $e_k$ denotes the one-hot vector with a 1 in the $k$-th position,
    * $\beta \in \mathbb{R}_+$ is the concentration, and
    * $\kappa \in \mathbb{R}_+$ is the `stickiness`.
    """
    def __init__(self, cycle_dim, num_states, concentration=1.1, stickiness=0.0):
        """
        Args:
            transition_matrix[j,k]: prob(hidden(t)=k | hidden(t-1)=j)
        """
        self.cycle_dim = cycle_dim
        self.num_states = num_states
        concentration = \
            concentration * jnp.ones((num_states, num_states)) + \
            stickiness * jnp.eye(num_states)
        # Tile the concentration so there is one Dirichlet prior per cycle position.
        concentration = jnp.tile(jnp.expand_dims(concentration, axis=0), reps=(self.cycle_dim, 1, 1))  # todo:
        self.concentration = concentration

    def distribution(self, params, state, inputs=None):
        # The state is a (cycle_index, state_index) pair. The next-state probabilities come
        # from the transition matrix at the *current* cycle position; only the stored cycle
        # index is advanced (mod cycle_dim).
        cycle_index, state_index = state
        return tfd.JointDistributionSequential([
            tfd.Deterministic((cycle_index + 1) % self.cycle_dim),
            tfd.Categorical(probs=params.transition_matrix[cycle_index.astype(jnp.uint32), state_index])
        ])

    def initialize(self, key=None, method="prior", transition_matrix=None):
        """Initialize the model parameters and their corresponding properties.

        Args:
            key (PRNGKey, optional): random number generator. Defaults to None.
            method (str, optional): initialization method. Defaults to "prior".
            transition_matrix (array, optional): manually specified transition matrices. Defaults to None.

        Returns:
            params and their corresponding properties.
        """
        if transition_matrix is None:
            this_key, key = jr.split(key)
            transition_matrix = tfd.Dirichlet(self.concentration).sample(seed=this_key)

        # Package the results into dictionaries
        params = ParamsCyclingHMMTransitions(transition_matrix=transition_matrix)
        props = ParamsCyclingHMMTransitions(transition_matrix=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
        return params, props

    def collect_suff_stats(self, params, posterior, inputs=None):
        # return posterior.trans_probs
        num_timesteps = posterior.trans_probs.shape[0]
        trans_probs = jnp.stack([
            posterior.trans_probs[jnp.arange(i, num_timesteps, step=self.cycle_dim)].sum(axis=0)  # todo:
            for i in range(self.cycle_dim)])
        return trans_probs
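# Illustrative note: with cycle_dim = 3 and, say, 6 rows in posterior.trans_probs,
# collect_suff_stats above sums rows {0, 3} into output slot 0, rows {1, 4} into slot 1,
# and rows {2, 5} into slot 2, producing a (cycle_dim, num_states, num_states) array of
# expected transition counts, one slot per cycle position.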
class ParamsCyclingHMMEmissions(NamedTuple):
    pass


class ParamsCyclingHMM(NamedTuple):
    initial: ParamsStandardHMMInitialState
    transitions: ParamsCyclingHMMTransitions
    emissions: None


class CyclingHMMEmissions(CategoricalHMMEmissions):
    def __init__(self,
                 num_states,
                 emission_dim):
        self.num_states = num_states
        self.emission_dim = emission_dim

    @property
    def emission_shape(self):
        return [(self.emission_dim,), (self.emission_dim,)]  # todo:
        # return (2, self.emission_dim,)
        # return ((self.emission_dim,), (self.emission_dim,))
        # return (self.emission_dim,)

    def distribution(self, params, state, inputs=None):
        cycle_index, state_index = state
        return tfd.JointDistributionSequential([
            tfd.Deterministic([cycle_index]),
            tfd.Deterministic([state_index])
        ])
        # return tfd.Deterministic(state_index) # todo:
        # return tfd.Deterministic([state_index]) # todo:
        # return tfd.Independent(
        #     tfd.Deterministic([state_index]),
        #     reinterpreted_batch_ndims=0)

    def log_prior(self, params):
        # todo: the emissions are fully deterministic, so the log prior is 0, right?
        return 0

    def _compute_conditional_logliks(self, params, emissions, inputs=None):
        # todo:
        # The second emission component is the (deterministically emitted) hidden state index;
        # give log-likelihood 0 to the matching state and -inf to every other state.
        a = emissions[1].reshape((-1,))
        a = jnp.round(a).astype(jnp.uint32)
        a = one_hot(a, num_classes=self.num_states)
        return jnp.where(a, jnp.zeros_like(a), jnp.full_like(a, fill_value=-jnp.inf))

    def initialize(self, key=jr.PRNGKey(0), method="prior"):
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Note: in the future we may support more initialization schemes, like K-Means.

        Args:
            key (PRNGKey, optional): random number generator for unspecified parameters. Must not be None if there are any unspecified parameters. Defaults to jr.PRNGKey(0).
            method (str, optional): method for initializing unspecified parameters. Currently, only "prior" is allowed. Defaults to "prior".

        Returns:
            params: nested dataclasses of arrays containing model parameters.
            props: a nested dictionary of ParameterProperties to specify parameter constraints and whether or not they should be trained.
        """
        # Add parameters to the dictionary
        params = ParamsCyclingHMMEmissions()
        props = ParamsCyclingHMMEmissions()
        return params, props

    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
        # todo: the emissions are fully deterministic, so return empty dict?
        return dict()

    def m_step(self, params, props, batch_stats, m_step_state):
        # todo: the emissions are fully deterministic, so nothing to maximize?
        return params, m_step_state
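# Illustrative note: for num_states = 2 and emitted state indices [0, 1, 1],
# _compute_conditional_logliks returns
#     [[  0., -inf],
#      [-inf,   0.],
#      [-inf,   0.]]
# i.e. log-likelihood 0 for the state that matches the deterministic emission and
# -inf for every other state.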
class CyclingCategoricalHMM(HMM):
    r"""An HMM with conditionally independent categorical emissions.

    Let $y_t \in \{1,\ldots,C\}^N$ denote a vector of $N$ conditionally independent
    categorical emissions from $C$ classes at time $t$. In this model, the emission
    distribution is,

    $$p(y_t \mid z_t, \theta) = \prod_{n=1}^N \mathrm{Cat}(y_{tn} \mid \theta_{z_t,n})$$
    $$p(\theta) = \prod_{k=1}^K \prod_{n=1}^N \mathrm{Dir}(\theta_{k,n}; \gamma 1_C)$$

    where $\theta_{k,n} \in \Delta_C$, for $k=1,\ldots,K$ and $n=1,\ldots,N$, are the
    *emission probabilities* and $\gamma$ is their prior concentration.

    :param cycle_dim: number of $K \times K$ transition matrices to cycle through
    :param num_states: number of discrete states $K$
    :param emission_dim: number of conditionally independent emissions $N$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.
    """
    def __init__(self,
                 cycle_dim: int,
                 num_states: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]] = 1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]] = 1.1,
                 transition_matrix_stickiness: Scalar = 0.0):
        self.cycle_dim = cycle_dim
        self.emission_dim = emission_dim
        initial_component = CyclingHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration)
        transition_component = CyclingHMMTransitions(cycle_dim, num_states,
                                                     concentration=transition_matrix_concentration,
                                                     stickiness=transition_matrix_stickiness)
        emission_component = CyclingHMMEmissions(num_states, emission_dim)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self,
                   key: jr.PRNGKey = jr.PRNGKey(0),
                   method: str = "prior",
                   initial_probs: Optional[Float[Array, "num_states"]] = None,
                   transition_matrix: Optional[Float[Array, "cycle_dim num_states num_states"]] = None,
                   ) -> Tuple[ParameterSet, PropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Note: in the future we may support more initialization schemes, like K-Means.

        Args:
            key (PRNGKey, optional): random number generator for unspecified parameters. Must not be None if there are any unspecified parameters. Defaults to jr.PRNGKey(0).
            method (str, optional): method for initializing unspecified parameters. Currently, only "prior" is allowed. Defaults to "prior".
            initial_probs (array, optional): manually specified initial state probabilities. Defaults to None.
            transition_matrix (array, optional): manually specified transition matrix. Defaults to None.

        Returns:
            Model parameters and their properties.
        """
        key1, key2, key3 = jr.split(key, 3)
        params, props = dict(), dict()
        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method,
                                                                                initial_probs=initial_probs)
        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method,
                                                                                            transition_matrix=transition_matrix)
        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method)
        return ParamsCyclingHMM(**params), ParamsCyclingHMM(**props)

    def e_step(self, params, emissions, inputs=None):
        """The E-step computes expected sufficient statistics under the
        posterior. In the generic case, we simply return the posterior itself.
        """
        initial_distribution, transition_matrix, log_likelihoods = self._inference_args(params, emissions, inputs)
        # Select the transition matrix for timestep t by cycling through the bank of matrices.
        transition_fn = lambda index: transition_matrix[index % self.cycle_dim]
        posterior = hmm_two_filter_smoother(initial_distribution=initial_distribution, log_likelihoods=log_likelihoods,
                                            transition_matrix=None,  # None because we use `transition_fn`
                                            transition_fn=transition_fn)
        initial_stats = self.initial_component.collect_suff_stats(params.initial, posterior, inputs)
        transition_stats = self.transition_component.collect_suff_stats(params.transitions, posterior, inputs)
        emission_stats = self.emission_component.collect_suff_stats(params.emissions, posterior, emissions, inputs)
        return (initial_stats, transition_stats, emission_stats), posterior.marginal_loglik

    def fit_sgd(
        self,
        params: ParameterSet,
        props: PropertySet,
        emissions: Union[Float[Array, "num_timesteps emission_dim"],
                         Float[Array, "num_batches num_timesteps emission_dim"]],
        inputs: Optional[Union[Float[Array, "num_timesteps input_dim"],
                               Float[Array, "num_batches num_timesteps input_dim"]]] = None,
        optimizer: optax.GradientTransformation = optax.adam(1e-3),
        batch_size: int = 1,
        num_epochs: int = 50,
        shuffle: bool = False,
        key: jr.PRNGKey = jr.PRNGKey(0)
    ) -> Tuple[ParameterSet, Float[Array, "niter"]]:
        r"""Compute parameter MLE/MAP estimate using Stochastic Gradient Descent (SGD).

        SGD aims to find parameters that maximize the marginal log probability,

        $$\theta^\star = \mathrm{argmax}_\theta \; \log p(y_{1:T}, \theta \mid u_{1:T})$$

        by minimizing the _negative_ of that quantity.

        *Note:* ``emissions`` *and* ``inputs`` *can either be single sequences or batches of sequences.*

        On each iteration, the algorithm grabs a *minibatch* of sequences and takes a gradient step.
        One pass through the entire set of sequences is called an *epoch*.

        Args:
            params: model parameters $\theta$
            props: properties specifying which parameters should be learned
            emissions: one or more sequences of emissions
            inputs: one or more sequences of corresponding inputs
            optimizer: an `optax` optimizer for minimization
            batch_size: number of sequences per minibatch
            num_epochs: number of epochs of SGD to run
            shuffle: whether or not to shuffle the sequences between epochs
            key: a random number generator for selecting minibatches

        Returns:
            tuple of new parameters and losses (negative scaled marginal log probs) over the course of SGD iterations.
        """
        # Make sure the emissions and inputs have batch dimensions
        batch_emissions = ensure_array_has_batch_dim(emissions, self.emission_shape)
        batch_inputs = ensure_array_has_batch_dim(inputs, self.inputs_shape)

        unc_params = to_unconstrained(params, props)

        def _loss_fn(unc_params, minibatch):
            """Default objective function."""
            params = from_unconstrained(unc_params, props)
            minibatch_emissions, minibatch_inputs = minibatch
            # Scale the minibatch log likelihood up to the size of the full dataset.
            num_timesteps = len(batch_emissions[0])
            scale = num_timesteps / len(minibatch_emissions[0])
            minibatch_lls = jax.vmap(partial(self.marginal_log_prob, params))(minibatch_emissions, minibatch_inputs)
            lp = self.log_prior(params) + minibatch_lls.sum() * scale
            return -lp / batch_emissions[0].size

        dataset = (batch_emissions, batch_inputs)
        unc_params, losses = run_sgd(_loss_fn,
                                     unc_params,
                                     dataset,
                                     optimizer=optimizer,
                                     batch_size=batch_size,
                                     num_epochs=num_epochs,
                                     shuffle=shuffle,
                                     key=key)

        params = from_unconstrained(unc_params, props)
        return params, losses
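# Note: fit_sgd above appears to mirror dynamax's generic HMM.fit_sgd, adapted to the
# pytree-valued emissions of this model (a (cycle index, state index) pair rather than a
# single array), which is why the loss indexes batch_emissions[0] for its sizes.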
def main():
    key = jr.PRNGKey(2)
    key, subkey = jr.split(key)

    # Define the model parameters
    cycle_dim = 3
    num_emissions = 1  # Only one emission at a time
    num_observable_states = 2

    # Initialize the parameters for the cycling model
    initial_probs = jnp.full((num_observable_states), 1.0 / num_observable_states)
    # transition_matrix = jnp.full((cycle_dim, num_observable_states, num_observable_states), 1.0)
    # initial_probs = None
    transition_matrix = None
    key, subkey = jr.split(key)
    transition_matrix = jax.nn.softmax(.2 * jr.normal(subkey, shape=(cycle_dim, num_observable_states, num_observable_states)), axis=-1)

    # Construct the CyclingCategoricalHMM
    hmm = CyclingCategoricalHMM(cycle_dim, num_observable_states, num_emissions,
                                # initial_probs_concentration=1.1,     # todo: default 1.1
                                # transition_matrix_concentration=.5,  # todo: default 1.1
                                # transition_matrix_stickiness=0.0     # todo: default 0.0
                                )

    # Initialize the parameters struct with known values
    key, subkey = jr.split(key)
    params, _ = hmm.initialize(subkey,
                               initial_probs=initial_probs,
                               transition_matrix=transition_matrix,
                               )

    # Generate synthetic data
    num_batches = 10000
    num_timesteps = 100
    key, subkey = jr.split(key)
    batch_states, batch_emissions = \
        jax.vmap(partial(hmm.sample, params, num_timesteps=num_timesteps))(
            jr.split(subkey, num_batches))
    print(f"batch_states.shape: {batch_states[1].shape}")
    print(f"batch_emissions.shape: {batch_emissions[1].shape}")
    # note that batch_states[0] and batch_emissions[0] correspond
    # to the cycle_dim index, i.e. timestep % cycle_dim

    def print_params(params):
        jnp.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
        print("Initial probs:")
        print(params.initial.probs)
        print("Transition matrices:")
        for i in range(cycle_dim):
            print(f"Transition matrix {i}:")
            print(params.transitions.transition_matrix[i])
        print('')

    print('True Params:')
    print_params(params)

    # Train the model using EM
    num_iters = 20
    key, subkey = jr.split(key)
    em_params, em_param_props = hmm.initialize(subkey)
    em_params, log_probs = hmm.fit_em(em_params,
                                      em_param_props,
                                      batch_emissions,
                                      num_iters=num_iters)

    sgd_params, sgd_param_props = hmm.initialize(key)
    sgd_key, key = jr.split(key)
    sgd_params, sgd_losses = hmm.fit_sgd(sgd_params,
                                         sgd_param_props,
                                         batch_emissions,
                                         optimizer=optax.sgd(learning_rate=1e-2, momentum=0.95),
                                         batch_size=num_batches // 10,
                                         num_epochs=num_iters,
                                         key=sgd_key)
    print('SGD Params:')
    print_params(sgd_params)

    # Compute the "losses" from EM
    em_losses = -log_probs / batch_emissions[1].size

    # Compute the loss if you used the parameters that generated the data
    true_loss = jax.vmap(partial(hmm.marginal_log_prob, params))(batch_emissions).sum()
    true_loss += hmm.log_prior(params)
    true_loss = -true_loss / batch_emissions[1].size

    # Plot the learning curve
    plt.plot(sgd_losses, label=f"SGD (mini-batch size = {num_batches//10})")
    plt.plot(em_losses, label="EM")
    plt.axhline(true_loss, color='k', linestyle=':', label="True Params")
    plt.legend()
    plt.xlim(-2, num_iters)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    _ = plt.title("Learning Curve")

    print('EM learned parameters:')
    print_params(em_params)
    plt.show()
def simple_main():
    key = jr.PRNGKey(0)
    key, subkey = jr.split(key)

    cycle_dim = 3
    num_states = 2
    num_emissions = 1
    num_timesteps = 100  # Simplified for testing

    initial_probs = jax.nn.softmax(.2 * jr.normal(subkey, (num_states,)))
    print("Initial Probabilities:", initial_probs)
    print("Sum of Initial Probabilities:", jnp.sum(initial_probs))

    key, subkey = jr.split(key)
    transition_matrix = jax.nn.softmax(2 * jr.normal(subkey, (cycle_dim, num_states, num_states)), axis=-1)
    print("Transition Matrices:")
    for i in range(cycle_dim):
        print(f"Transition matrix {i}:\n{transition_matrix[i]}")

    hmm = CyclingCategoricalHMM(cycle_dim, num_states, num_emissions)
    key, subkey = jr.split(key)
    params, _ = hmm.initialize(subkey, initial_probs=initial_probs, transition_matrix=transition_matrix)

    key, subkey = jr.split(key)
    states, emissions = hmm.sample(params, num_timesteps=num_timesteps, key=subkey)
    print("Sampled States:\n", states)
    print("Sampled Emissions:\n", emissions)

    (initial_stats, transition_stats, emission_stats), log_likelihood = hmm.e_step(params, emissions)
    log_likelihood /= num_timesteps
    print("Initial Stats:", initial_stats)
    print("Transition Stats:", transition_stats)
    print("Emission Stats:", emission_stats)
    print("Log Likelihood:", log_likelihood)
if __name__ == '__main__':
    main()
    # simple_main()
```

Output:
Sorry to ask a question that probably isn't bug-related, but I didn't think I'd get help elsewhere (StackOverflow etc.). It's minorly related to #310.

I have runnable code of the following toy problem, but it doesn't learn. Here's a summary of the code:
- `CyclingHMMInitialState` subclasses `StandardHMMInitialState`. It implements `distribution` by returning a `tfd.JointDistributionSequential` of two values. The first is the integer tracking which of the $M$ transition matrices we're using. Since it's the initial state, we start with `tfd.Deterministic(0)`. The second value is `tfd.Categorical(probs=params.probs)`, based on the HMM casino tutorial.
- `CyclingHMMTransitions` subclasses `StandardHMMTransitions`. We have to implement `concentration` and `transition_matrix` to be 3D instead of 2D. We also implement `distribution` to increment the choice of transition matrix mod `self.cycle_dim`. The implementation of `collect_suff_stats` is interesting and possibly wrong (see the sketch after this list). Note that `posterior.trans_probs`, which the superclass's method would return, is shaped $T \times K \times K$. I think we want this function to return a matrix that's shaped $M \times K \times K$. If we consider what should be at `[0, ..., ...]` in this output matrix, it's based on the transitions at timesteps $t$ where $\mathrm{mod}(t, M) = 0$. The output at `[1, ..., ...]` is based on all the transitions where $\mathrm{mod}(t, M) = 1$, and so on.
- `CyclingHMMEmissions` subclasses `CategoricalHMMEmissions`. Its distribution is fully deterministic. Therefore it has no learnable parameters, and so `log_prior` just returns 0. It also implements `_compute_conditional_logliks` (see the full code). This output is shaped $T \times K$. Based on the emissions, the log-likelihood is either 0 ($\log(1)$) or `-jnp.inf` (the limit of $\log(0)$). I'm not sure about this function, but I also tried just returning `jnp.zeros(shape=(num_timesteps, self.num_states))` where `num_timesteps = emissions[0].shape[0]`, and that didn't lead to successful learning.
- Last, `CyclingCategoricalHMM` subclasses `HMM`. Its implementation of `e_step` uses a `transition_fn` which selects the $\mathrm{mod}(t, M)$-th transition matrix.
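For reference, here is a minimal sketch of the grouping that `collect_suff_stats` is attempting (plain `jax.numpy`, made-up sizes): collapsing the $T \times K \times K$ expected transition counts into an $M \times K \times K$ array with one slot per cycle position.

```python
import jax.numpy as jnp

cycle_dim, num_states, num_timesteps = 3, 2, 6   # hypothetical sizes (M, K, T)
# Stand-in for posterior.trans_probs: one K x K matrix of expected transition counts per timestep.
trans_probs = jnp.arange(num_timesteps * num_states * num_states, dtype=jnp.float32)
trans_probs = trans_probs.reshape(num_timesteps, num_states, num_states)

# Slot i sums the matrices at timesteps i, i + cycle_dim, i + 2 * cycle_dim, ...
grouped = jnp.stack([
    trans_probs[jnp.arange(i, num_timesteps, step=cycle_dim)].sum(axis=0)
    for i in range(cycle_dim)])

print(grouped.shape)  # (3, 2, 2)
# grouped[0] == trans_probs[0] + trans_probs[3], grouped[1] == trans_probs[1] + trans_probs[4], etc.
```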
Here's the full code and the output (posted above).