Skip to content

Commit

Permalink
Merge pull request #340 from jakevdp/row_stack
Browse files Browse the repository at this point in the history
Replace jnp.row_stack with jnp.vstack
  • Loading branch information
murphyk authored Aug 24, 2023
2 parents b8ce43e + d7354f3 commit 95a09ba
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions dynamax/generalized_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ def _step(carry, args):
_, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args)

# Reverse the arrays and return
smoothed_means = jnp.row_stack((smoothed_means[::-1], filtered_means[-1][None, ...]))
smoothed_covs = jnp.row_stack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
return PosteriorGSSMSmoothed(
marginal_loglik=ll,
filtered_means=filtered_means,
Expand Down
2 changes: 1 addition & 1 deletion dynamax/hidden_markov_model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _step(carry, args):
_, rev_smoothed_probs = lax.scan(_step, carry, args)

# Reverse the arrays and return
smoothed_probs = jnp.row_stack([rev_smoothed_probs[::-1], filtered_probs[-1]])
smoothed_probs = jnp.vstack([rev_smoothed_probs[::-1], filtered_probs[-1]])

# Package into a posterior
posterior = HMMPosterior(
Expand Down
2 changes: 1 addition & 1 deletion dynamax/hidden_markov_model/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_hmm_filter(key=0, num_timesteps=3, num_states=2):
# Compare predicted_probs to manually computed entries
for t in range(num_timesteps):
log_joint_t = big_log_joint(initial_probs, transition_matrix,
jnp.row_stack([log_lkhds[:t], jnp.zeros(num_states)]))
jnp.vstack([log_lkhds[:t], jnp.zeros(num_states)]))

log_joint_t -= logsumexp(log_joint_t)
predicted_probs_t = jnp.exp(logsumexp(log_joint_t, axis=tuple(jnp.arange(t))))
Expand Down
4 changes: 2 additions & 2 deletions dynamax/hidden_markov_model/models/arhmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,14 @@ def _step(carry, key):
key1, key2 = jr.split(key, 2)
state = self.transition_distribution(params, prev_state).sample(seed=key2)
emission = self.emission_distribution(params, state, inputs=jnp.ravel(prev_emissions)).sample(seed=key1)
next_prev_emissions = jnp.row_stack([emission, prev_emissions[:-1]])
next_prev_emissions = jnp.vstack([emission, prev_emissions[:-1]])
return (state, next_prev_emissions), (state, emission)

# Sample the initial state
key1, key2, key = jr.split(key, 3)
initial_state = self.initial_distribution(params).sample(seed=key1)
initial_emission = self.emission_distribution(params, initial_state, inputs=jnp.ravel(prev_emissions)).sample(seed=key2)
initial_prev_emissions = jnp.row_stack([initial_emission, prev_emissions[:-1]])
initial_prev_emissions = jnp.vstack([initial_emission, prev_emissions[:-1]])

# Sample the remaining emissions and states
next_keys = jr.split(key, num_timesteps - 1)
Expand Down
6 changes: 3 additions & 3 deletions dynamax/linear_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,8 @@ def _step(carry, args):
_, (smoothed_means, smoothed_covs, smoothed_cross) = lax.scan(_step, init_carry, args)

# Reverse the arrays and return
smoothed_means = jnp.row_stack((smoothed_means[::-1], filtered_means[-1][None, ...]))
smoothed_covs = jnp.row_stack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
smoothed_cross = smoothed_cross[::-1]
return PosteriorGSSMSmoothed(
marginal_loglik=ll,
Expand Down Expand Up @@ -610,5 +610,5 @@ def _step(carry, args):
jnp.arange(num_timesteps - 2, -1, -1),
)
_, reversed_states = lax.scan(_step, last_state, args)
states = jnp.row_stack([reversed_states[::-1], last_state])
states = jnp.vstack([reversed_states[::-1], last_state])
return states
4 changes: 2 additions & 2 deletions dynamax/linear_gaussian_ssm/info_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ def _smooth_step(carry, args):
_, (smoothed_etas, smoothed_precisions) = lax.scan(_smooth_step, init_carry, args)

# Reverse the arrays and return
smoothed_etas = jnp.row_stack((smoothed_etas[::-1], filtered_etas[-1][None, ...]))
smoothed_precisions = jnp.row_stack((smoothed_precisions[::-1], filtered_precisions[-1][None, ...]))
smoothed_etas = jnp.vstack((smoothed_etas[::-1], filtered_etas[-1][None, ...]))
smoothed_precisions = jnp.vstack((smoothed_precisions[::-1], filtered_precisions[-1][None, ...]))
return PosteriorGSSMInfoSmoothed(
marginal_loglik=ll,
filtered_etas=filtered_etas,
Expand Down
6 changes: 3 additions & 3 deletions dynamax/nonlinear_gaussian_ssm/inference_ekf.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ def _step(carry, args):
_, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args)

# Reverse the arrays and return
smoothed_means = jnp.row_stack((smoothed_means[::-1], filtered_means[-1][None, ...]))
smoothed_covs = jnp.row_stack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
return PosteriorGSSMSmoothed(
marginal_loglik=ll,
filtered_means=filtered_means,
Expand Down Expand Up @@ -311,7 +311,7 @@ def _step(carry, args):
jnp.arange(num_timesteps - 2, -1, -1),
)
_, reversed_states = lax.scan(_step, last_state, args)
states = jnp.row_stack([reversed_states[::-1], last_state])
states = jnp.vstack([reversed_states[::-1], last_state])
return states


Expand Down
4 changes: 2 additions & 2 deletions dynamax/nonlinear_gaussian_ssm/inference_ukf.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ def _step(carry, args):
_, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args)

# Reverse the arrays and return
smoothed_means = jnp.row_stack((smoothed_means[::-1], filtered_means[-1][None, ...]))
smoothed_covs = jnp.row_stack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
return PosteriorGSSMSmoothed(
marginal_loglik=ll,
filtered_means=filtered_means,
Expand Down

0 comments on commit 95a09ba

Please sign in to comment.