Skip to content

Commit

Permalink
Add most likely sequence function
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Jun 6, 2024
1 parent aec03d9 commit db7c5be
Show file tree
Hide file tree
Showing 2 changed files with 380 additions and 2 deletions.
171 changes: 171 additions & 0 deletions src/non_local_detector/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,86 @@ def chunked_filter_smoother(
)


@jax.jit
def viterbi(
initial_distribution: jnp.ndarray,
transition_matrix: jnp.ndarray,
log_likelihoods: jnp.ndarray,
) -> jnp.ndarray:
r"""Compute the most likely state sequence. This is called the Viterbi algorithm.
Parameters
----------
initial_distribution : jnp.ndarray, shape (n_states,)
Initial state distribution
transition_matrix : jnp.ndarray, shape (n_states, n_states)
Transition matrix
log_likelihoods : jnp.ndarray, shape (n_time, n_states)
Log likelihoods for each state at each time point
Returns
-------
most_likely_state_sequence : jnp.ndarray, shape (n_time,)
"""

# Run the backward pass
def _backward_pass(best_next_score, t):
scores = jnp.log(transition_matrix) + best_next_score + log_likelihoods[t + 1]
best_next_state = jnp.argmax(scores, axis=1)
best_next_score = jnp.max(scores, axis=1)
return best_next_score, best_next_state

num_timesteps, num_states = log_likelihoods.shape
best_second_score, best_next_states = jax.lax.scan(
_backward_pass,
jnp.zeros(num_states),
jnp.arange(num_timesteps - 1),
reverse=True,
)

# Run the forward pass
def _forward_pass(state, best_next_state):
next_state = best_next_state[state]
return next_state, next_state

first_state = jnp.argmax(
jnp.log(initial_distribution) + log_likelihoods[0] + best_second_score
)
_, states = jax.lax.scan(_forward_pass, first_state, best_next_states)

return jnp.concatenate([jnp.array([first_state]), states])


def most_likely_sequence(
time: np.ndarray,
initial_distribution: np.ndarray,
transition_matrix: np.ndarray,
log_likelihood_func: callable,
log_likelihood_args: tuple,
is_missing: Optional[np.ndarray] = None,
log_likelihoods: Optional[np.ndarray] = None,
n_chunks: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:

if n_chunks > 1:
raise NotImplementedError("Chunked Viterbi is not yet implemented.")

if log_likelihoods is None:
log_likelihoods = (
log_likelihood_func(
time,
*log_likelihood_args,
is_missing=is_missing,
),
)
return viterbi(
initial_distribution=initial_distribution,
transition_matrix=transition_matrix,
log_likelihoods=log_likelihoods,
)


## Covariate dependent filtering and smoothing ##
def _get_transition_matrix(
discrete_transition_matrix_t: jnp.ndarray,
Expand Down Expand Up @@ -570,6 +650,97 @@ def chunked_filter_smoother_covariate_dependent(
)


@jax.jit
def viterbi_covariate_dependent(
initial_distribution: jnp.ndarray,
discrete_transition_matrix: jnp.ndarray,
continuous_transition_matrix: jnp.ndarray,
state_ind: jnp.ndarray,
log_likelihoods: jnp.ndarray,
) -> jnp.ndarray:
r"""Compute the most likely state sequence. This is called the Viterbi algorithm.
Parameters
----------
initial_distribution : jnp.ndarray, shape (n_states,)
Initial state distribution
transition_matrix : jnp.ndarray, shape (n_states, n_states)
Transition matrix
log_likelihoods : jnp.ndarray, shape (n_time, n_states)
Log likelihoods for each state at each time point
Returns
-------
most_likely_state_sequence : jnp.ndarray, shape (n_time,)
"""

# Run the backward pass
def _backward_pass(best_next_score, args):
t, discrete_transition_matrix_t = args
transition_matrix = _get_transition_matrix(
discrete_transition_matrix_t,
continuous_transition_matrix,
state_ind,
)
scores = jnp.log(transition_matrix) + best_next_score + log_likelihoods[t + 1]
best_next_state = jnp.argmax(scores, axis=1)
best_next_score = jnp.max(scores, axis=1)
return best_next_score, best_next_state

num_timesteps, num_states = log_likelihoods.shape
best_second_score, best_next_states = jax.lax.scan(
_backward_pass,
jnp.zeros(num_states),
(jnp.arange(num_timesteps - 1), discrete_transition_matrix[:-1]),
reverse=True,
)

# Run the forward pass
def _forward_pass(state, best_next_state):
next_state = best_next_state[state]
return next_state, next_state

first_state = jnp.argmax(
jnp.log(initial_distribution) + log_likelihoods[0] + best_second_score
)
_, states = jax.lax.scan(_forward_pass, first_state, best_next_states)

return jnp.concatenate([jnp.array([first_state]), states])


def most_likely_sequence_covariate_dependent(
time: np.ndarray,
state_ind: np.ndarray,
initial_distribution: np.ndarray,
discrete_transition_matrix: np.ndarray,
continuous_transition_matrix: np.ndarray,
log_likelihood_func: callable,
log_likelihood_args: tuple,
is_missing: Optional[np.ndarray] = None,
log_likelihoods: Optional[np.ndarray] = None,
n_chunks: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
if n_chunks > 1:
raise NotImplementedError("Chunked Viterbi is not yet implemented.")

if log_likelihoods is None:
log_likelihoods = (
log_likelihood_func(
time,
*log_likelihood_args,
is_missing=is_missing,
),
)
return viterbi_covariate_dependent(
initial_distribution=initial_distribution,
discrete_transition_matrix=discrete_transition_matrix,
continuous_transition_matrix=continuous_transition_matrix,
state_ind=state_ind,
log_likelihoods=log_likelihoods,
)


## Convergence check ##
def check_converged(
log_likelihood: np.ndarray,
Expand Down
Loading

0 comments on commit db7c5be

Please sign in to comment.