Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gibbs Sampling Notebook + Code #2

Open
wants to merge 11 commits into base branch `main`.
236 changes: 169 additions & 67 deletions dynamax/slds_eqx/gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import jax.tree as tree
import operator
import optax

import equinox as eqx
import blackjax
from typing import Tuple, Callable
from jax import grad, lax, vmap
from jaxtyping import Array, Float
from tensorflow_probability.substrates import jax as tfp
from typing import Callable, Literal, Optional
tfd = tfp.distributions
tfb = tfp.bijectors
MVN = tfd.MultivariateNormalFullCovariance
Expand All @@ -19,24 +22,41 @@

def fit_gibbs(slds : SLDS,
key : jr.PRNGKey,
emissions : Float[Array["num_timesteps emission_dim"]], #array of floats of dim num_timesteps x emission_dim
initial_zs : Float[Array["num_timesteps"]],
initial_xs : Float[Array["num_timesteps latent_dim"]],
num_iters : int = 100
emissions : jnp.ndarray,
initial_zs : jnp.ndarray,
initial_xs :jnp.ndarray,
num_iters : int = 100,
lr : float = 1e-3,
reg_schedule : Callable[[int], float] = lambda t: 1.0,
param_update_method: Literal["gradient", "hmc"] = "gradient",
param_update_iters : int = 10,
hmc_num_samples : int = 100,
hmc_num_warmup : int = 100,
hmc_step_size : float = 1e-3,
hmc_num_integration_steps : int = 10,
frozen_params: Optional[List[str]] = None
):
#TODO: Look at jax.lax.stop_gradient
"""
Run a Gibbs sampler to draw (approximate) samples from the posterior distribution over
discrete and continuous latent states of an SLDS.
"""
K = slds.num_states
D = slds.latent_dim
N = slds.emission_dim
ys = emissions # num_timesteps x emission_dim
ys = emissions

if frozen_params is None:
frozen_params = []

def param_filter(param):
return not any(param.name.startswith(fp) for fp in frozen_params)

#theta denotes parameters of the model
#z denotes discrete latent states
#x denotes continuous latent states
#y denotes emissions
if param_update_method == "gradient":
optimizer = optax.adam(lr)
opt_state = optimizer.init(eqx.filter(slds, param_filter))
else:
opt_state = None

def _update_discrete_states(slds, key1, xs):
"""
Expand All @@ -53,20 +73,17 @@ def _update_discrete_states(slds, key1, xs):

# log p(x_t | x_{t-1}, z_t=k) for all t=2,...,T and all k=1,...,K
f = lambda z: vmap(lambda x, xn: slds.dynamics_distn(z, x).log_prob(xn))(xs[:-1], xs[1:]) # [K] -> (T-1,)
lls = vmap(f, jnp.arange(K)).T # (T-1,K)
lls = vmap(f)(jnp.arange(K)).T # (T-1,K)

# Stack the initial log prob and subsequent log probs into one array
lls = jnp.vstack([ll0, lls])
return hmm.inference.hmm_posterior_sample(key1, pi0, P, lls)

return hmm.inference.hmm_posterior_sample(key1, pi0, P, lls)[1]

def _update_continuous_states(slds, key2, ys, zs):
# TODO: sample from the p(x | z, y) by using lgssm_posterior_sample
# and giving the function time-varying parameters A_t = A_{z_t}
"""
Update the continuous states by drawing a sample from p(x | z, y)
"""
T = ys.shape[0] #number of timesteps

# Initialize time-varying parameters
As = slds.dynamics_matrices
Expand All @@ -75,40 +92,33 @@ def _update_continuous_states(slds, key2, ys, zs):
C = slds.emission_matrix
d = slds.emission_bias
R = slds.emission_cov
initial_mean = jnp.zeros(D)
initial_cov = jnp.eye(D)

# Compute parameters for each time step using the discrete states
A_t = vmap(lambda z: As[z])(zs)
b_t = vmap(lambda z: bs[z])(zs)
Q_t = vmap(lambda z: Qs[z])(zs)

# Create ParamsLGSSM object to pass into lgssm_posterior_sample
params = lgssm.inference.ParamsLGSSM(
initial= lgssm.inference.ParamsLGSSMInitial(
mean=jnp.zeros(D),
cov=jnp.eye(D) #identity matrix - assumes initial latent dimensions are uncorrelated
),
dynamics=lgssm.inference.ParamsLGSSMDynamics(
#ntime x state_dim x state_dim
weights=A_t,
bias=b_t,
input_weights=None,
cov=Q_t
),
emissions=lgssm.inference.ParamsLGSSMEmissions(
#broadcasting C, d, R to have shape (T, N, D), (T, N), (T, N, N)
weights=jnp.repeat(C[None, :, :], T, axis=0),
bias=jnp.repeat(d[None, :], T, axis=0),
input_weights=None,
cov=jnp.repeat(R[None, :, :], T, axis=0)
),
A_t = As[zs]
b_t = bs[zs]
Q_t = Qs[zs]

params = lgssm.inference.make_lgssm_params(
initial_mean=initial_mean,
initial_cov=initial_cov,
dynamics_weights=A_t,
dynamics_cov=Q_t,
emissions_weights=C,
emissions_cov=R,
dynamics_bias=b_t,
dynamics_input_weights=None,
emissions_bias=d,
emissions_input_weights=None,
)

# Sample from the posterior distribution
xs = lgssm.inference.lgssm_posterior_sample(key2, params, ys)

return xs

def _update_params(slds, ys, zs, xs, lr=1e-3, reg=1.0, num_iters=10):
def _update_params_gradient(slds, ys, zs, xs, opt_state, reg=1.0, num_iters=10):
r"""
Goal: maximize the expected log probability as a function of parameters \theta:
L(\theta) = E_{p(z, x | y, \theta')}[log p(y, z, x; \theta)]
Expand All @@ -133,37 +143,46 @@ def _update_params(slds, ys, zs, xs, lr=1e-3, reg=1.0, num_iters=10):

The final objective combines these two terms.
"""

T = ys.shape[0]
def loss(curr_slds):
L = -1 * curr_slds.log_prob(ys, zs, xs) / T
L += 0.5 * reg * tree.reduce(
operator.add,
tree.map(lambda x, y: jnp.sum((x - y)**2), curr_slds, slds),
tree.map(lambda x, y: jnp.sum((x - y)**2),
eqx.filter(curr_slds, param_filter),
eqx.filter(slds, param_filter)),
0.0)
return L

# Minimize the loss with optax
# TODO: replace for loop with a scan
optimizer = optax.adam(lr)
opt_state = optimizer.init(slds)
for _ in range(num_iters):
grads = grad(loss)(slds)
updates, opt_state = optimizer.update(grads, opt_state)
slds = optax.apply_updates(slds, updates)
return slds

def _step(carry, step_size): #not using step_size here (num_iters is used instead)
# TODO
# 1. call _update_discrete_states
# 2. call _update_continuous_states
# 3. compute the log joint probability (using slds.log_prob)
# 4. return new_carry and output lp

# Define a single step of the optimization
@eqx.filter_jit
def step(carry, _):
curr_slds, opt_state = carry
grads = eqx.filter_grad(loss)(curr_slds)
updates, new_opt_state = optimizer.update(grads, opt_state)
new_slds = eqx.apply_updates(curr_slds, updates, where=param_filter)
return (new_slds, new_opt_state), None

# Run the optimization using lax.scan
(final_slds, final_opt_state), _ = lax.scan(step, (slds, opt_state), None, length=num_iters)

return final_slds, final_opt_state

def _update_params_hmc(slds, ys, zs, xs, key, reg=1.0):
    """
    Update model parameters with one HMC pass (see `update_parameters_hmc`).

    Thin wrapper that forwards the HMC settings captured from the enclosing
    `fit_gibbs` call (`hmc_num_samples`, `hmc_num_warmup`, `hmc_step_size`,
    `hmc_num_integration_steps`) together with the per-iteration
    regularization strength `reg`.
    """
    return update_parameters_hmc(slds, ys, zs, xs, key,
                                 num_samples=hmc_num_samples,
                                 num_warmup=hmc_num_warmup,
                                 step_size=hmc_step_size,
                                 num_integration_steps=hmc_num_integration_steps,
                                 reg=reg)

def _step(carry, t):
# Unpack Carry
zs, xs, slds, key = carry
zs, xs, slds, opt_state, key = carry

# Update Key to generate new random samples
key, subkey1, subkey2 = jr.split(key, 3)
key, subkey1, subkey2, subkey3 = jr.split(key, 4)

# Update Discrete States p(z₁:ₜ | x₁:ₜ, θ)
zs = _update_discrete_states(slds, subkey1, xs)
Expand All @@ -174,19 +193,102 @@ def _step(carry, step_size): #not using step_size here (num_iters is used instea
# Compute Log Joint Probability log p(y₁:ₜ, z₁:ₜ, x₁:ₜ | θ)
lp = slds.log_prob(ys, zs, xs)

# Compute regularization strength for this iteration
reg = reg_schedule(t)

# Update Parameters
slds = _update_params(slds, ys, zs, xs)
if param_update_method == "gradient":
slds, opt_state = _update_params_gradient(slds, ys, zs, xs, opt_state, reg, param_update_iters)
else: # HMC
slds, _ = _update_params_hmc(slds, ys, zs, xs, subkey3, reg)

# Return New Carry and Output Log Probability
new_carry = (zs, xs, slds, key)
new_carry = (zs, xs, slds, opt_state, key)

return new_carry, lp

# TODO: initialize carry and call scan
initial_carry = (initial_zs, initial_xs, slds, key)
final_carry, lps = lax.scan(_step, initial_carry, jnp.arange(num_iters)) #step_size is num_iters
initial_carry = (initial_zs, initial_xs, slds, opt_state, key)
final_carry, lps = lax.scan(_step, initial_carry, jnp.arange(num_iters))

# Unpack Final Carry
zs, xs, slds, key = final_carry
zs, xs, slds, _, _ = final_carry

return slds, lps, zs, xs

def update_parameters_hmc(slds: SLDS,
                          ys: jnp.ndarray,
                          zs: jnp.ndarray,
                          xs: jnp.ndarray,
                          key: jr.PRNGKey,
                          num_samples: int = 100,
                          num_warmup: int = 100,
                          step_size: float = 1e-3,
                          num_integration_steps: int = 10,
                          reg: float = 1.0) -> Tuple[SLDS, jnp.ndarray]:
    """
    Update SLDS parameters using Hamiltonian Monte Carlo.

    Draws `num_warmup + num_samples` HMC samples of the model parameters from
    the L2-regularized conditional posterior p(theta | y, z, x), discards the
    warmup draws, and returns an SLDS whose parameters are set to the mean of
    the retained draws.

    This function can be used as an alternative to the gradient-based
    parameter update in the Gibbs sampling loop.

    Args:
        slds: current model. Assumed to be an equinox Module (the file calls
            `eqx.filter(slds, ...)` elsewhere), i.e. an immutable pytree.
        ys: emissions — presumably (num_timesteps, emission_dim); confirm with caller.
        zs: discrete latent states, one per timestep.
        xs: continuous latent states, one per timestep.
        key: PRNG key.
        num_samples: number of post-warmup samples to keep.
        num_warmup: number of initial samples to discard.
        step_size: leapfrog step size.
        num_integration_steps: leapfrog steps per HMC proposal.
        reg: strength of the L2 penalty tying proposals to the current params.

    Returns:
        (updated_slds, samples) where `samples` is a tuple of arrays, each
        with a leading axis of length `num_samples`.
    """
    def _get_params(m):
        # The parameter leaves we sample over, in a fixed order.
        return (m.pi0, m.transition_matrix, m.dynamics_matrices,
                m.dynamics_biases, m.dynamics_covs,
                m.emission_matrix, m.emission_bias, m.emission_cov)

    def _with_params(m, params):
        # eqx.Module instances are frozen dataclasses, so swap leaves with
        # eqx.tree_at rather than attribute assignment (which would raise).
        return eqx.tree_at(_get_params, m, params)

    # Pack current parameters (also the HMC initial position).
    current_params = _get_params(slds)

    def log_posterior(params):
        # Log joint under the proposed parameters...
        temp_slds = _with_params(slds, params)
        log_prob = temp_slds.log_prob(ys, zs, xs)
        # ...minus an L2 penalty keeping the proposal near the current params.
        reg_term = 0.5 * reg * sum(jnp.sum((p - q) ** 2)
                                   for p, q in zip(params, current_params))
        return log_prob - reg_term

    # Identity inverse mass matrix sized to the flattened position.
    # (jnp.ones_like(tuple) is invalid — ones_like does not accept a pytree.)
    flat_dim = sum(p.size for p in current_params)
    inverse_mass_matrix = jnp.ones(flat_dim)

    # Set up HMC
    hmc = blackjax.hmc(log_posterior, step_size, inverse_mass_matrix,
                       num_integration_steps)
    state = hmc.init(current_params)

    def one_step(state, step_key):
        state, _ = hmc.step(step_key, state)
        return state, state.position

    # lax.scan compiles its body, so no explicit jax.jit wrapper is needed.
    keys = jr.split(key, num_samples + num_warmup)
    _, samples = lax.scan(one_step, state, keys)

    # Discard warmup draws along the *sample* axis of every leaf. Slicing the
    # position tuple itself (samples[num_warmup:]) would drop whole parameters
    # instead of warmup iterations.
    samples = tree.map(lambda s: s[num_warmup:], samples)

    # Point estimate: posterior mean of the retained draws.
    mean_params = tree.map(lambda s: jnp.mean(s, axis=0), samples)

    return _with_params(slds, mean_params), samples

Loading