Skip to content

Commit

Permalink
singleview EKS cleanup (#22)
Browse files Browse the repository at this point in the history
* align APIs for singleview and ibl pupil

* remove unnecessary outputs from jax ensemble function

* flake8 cleanup
  • Loading branch information
themattinthehatt authored Dec 11, 2024
1 parent 041a148 commit a9a14c5
Show file tree
Hide file tree
Showing 9 changed files with 338 additions and 404 deletions.
58 changes: 30 additions & 28 deletions eks/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import partial
from collections import defaultdict

import jax
import jax.scipy as jsc
Expand Down Expand Up @@ -257,24 +256,29 @@ def jax_ensemble(
markers_3d_array: np.ndarray,
avg_mode: str = 'median',
var_mode: str = 'confidence_weighted_var',
):
) -> tuple:
"""
Computes ensemble median (or mean) and variance of a 3D array of DLC marker data using JAX.
Compute ensemble mean/median and variance of a 3D marker array using JAX.
Args:
markers_3d_array
markers_3d_array: shape (n_models, samples, 3 * n_keypoints); "3" is for x, y, likelihood
avg_mode
'median' | 'mean'
var_mode
'confidence_weighted_var' | 'var'
Returns:
ensemble_preds: np.ndarray
shape (n_timepoints, n_keypoints, n_coordinates).
ensembled predictions for each keypoint for each target
ensemble_vars: np.ndarray
shape (n_timepoints, n_keypoints, n_coordinates).
ensembled variances for each keypoint for each target
tuple:
ensemble_preds: np.ndarray
shape (n_timepoints, n_keypoints, n_coordinates).
ensembled predictions for each keypoint
ensemble_vars: np.ndarray
shape (n_timepoints, n_keypoints, n_coordinates).
ensembled variances for each keypoint
ensemble_likes: np.ndarray
shape (n_timepoints, n_keypoints, 1).
mean likelihood for each keypoint
"""
markers_3d_array = jnp.array(markers_3d_array) # Convert to JAX array
n_frames = markers_3d_array.shape[1]
Expand All @@ -283,6 +287,7 @@ def jax_ensemble(
# Initialize output structures
ensemble_preds = np.zeros((n_frames, n_keypoints, 2))
ensemble_vars = np.zeros((n_frames, n_keypoints, 2))
ensemble_likes = np.zeros((n_frames, n_keypoints, 1))

# Choose the appropriate JAX function based on the mode
if avg_mode == 'median':
Expand All @@ -300,9 +305,10 @@ def compute_stats(i):
avg_x = avg_func(data_x)
avg_y = avg_func(data_y)

conf_per_keypoint = jnp.sum(data_likelihood, axis=0)
mean_conf_per_keypoint = conf_per_keypoint / data_likelihood.shape[0]

if var_mode in ['conf_weighted_var', 'confidence_weighted_var']:
conf_per_keypoint = jnp.sum(data_likelihood, axis=0)
mean_conf_per_keypoint = conf_per_keypoint / data_likelihood.shape[0]
var_x = jnp.nanvar(data_x, axis=0) / mean_conf_per_keypoint
var_y = jnp.nanvar(data_y, axis=0) / mean_conf_per_keypoint
elif var_mode in ['var', 'variance']:
Expand All @@ -311,28 +317,25 @@ def compute_stats(i):
else:
raise ValueError(f"{var_mode} for variance computation not supported")

return avg_x, avg_y, var_x, var_y
return avg_x, avg_y, var_x, var_y, mean_conf_per_keypoint

compute_stats_jit = jax.jit(compute_stats)
stats = jax.vmap(compute_stats_jit)(jnp.arange(n_keypoints))

avg_x, avg_y, var_x, var_y = stats
avg_x, avg_y, var_x, var_y, likes = stats

keypoints_avg_dict = {}
for i in range(n_keypoints):
ensemble_preds[:, i, 0] = avg_x[i]
ensemble_preds[:, i, 1] = avg_y[i]
ensemble_vars[:, i, 0] = var_x[i]
ensemble_vars[:, i, 1] = var_y[i]
keypoints_avg_dict[2 * i] = avg_x[i]
keypoints_avg_dict[2 * i + 1] = avg_y[i]
ensemble_likes[:, i, 0] = likes[i]

# Convert outputs to JAX arrays
ensemble_preds = jnp.array(ensemble_preds)
ensemble_vars = jnp.array(ensemble_vars)
keypoints_avg_dict = {k: jnp.array(v) for k, v in keypoints_avg_dict.items()}

return ensemble_preds, ensemble_vars, keypoints_avg_dict
return ensemble_preds, ensemble_vars, ensemble_likes


def kalman_filter_step(carry, curr_y):
Expand Down Expand Up @@ -720,27 +723,26 @@ def eks_zscore(eks_predictions, ensemble_means, ensemble_vars, min_ensemble_std=


def compute_covariance_matrix(ensemble_preds):
"""
Compute the covariance matrix E for correlated noise dynamics.
"""Compute the covariance matrix E for correlated noise dynamics.
Parameters:
ensemble_preds: A 3D array of shape (T, n_keypoints, n_coords)
containing the ensemble predictions.
Args:
ensemble_preds: shape (T, n_keypoints, n_coords) containing the ensemble predictions.
Returns:
E: A 2K x 2K covariance matrix where K is the number of keypoints.
E: A 2K x 2K covariance matrix where K is the number of keypoints.
"""
# Get the number of time steps, keypoints, and coordinates
T, n_keypoints, n_coords = ensemble_preds.shape

# Flatten the ensemble predictions to shape (T, 2K) where K is the number of keypoints
flattened_preds = ensemble_preds.reshape(T, -1)
# flattened_preds = ensemble_preds.reshape(T, -1)

# Compute the temporal differences
temporal_diffs = np.diff(flattened_preds, axis=0)
# temporal_diffs = np.diff(flattened_preds, axis=0)

# Compute the covariance matrix of the temporal differences
E = np.cov(temporal_diffs, rowvar=False)
# E = np.cov(temporal_diffs, rowvar=False)

# Index covariance matrix into blocks for each keypoint
cov_mats = []
Expand Down
59 changes: 26 additions & 33 deletions eks/ibl_pupil_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,26 +83,28 @@ def add_mean_to_array(pred_arr, keys, mean_x, mean_y):
def fit_eks_pupil(
input_source: Union[str, list],
save_file: str,
smooth_params: list,
smooth_params: Optional[list] = None,
s_frames: Optional[list] = None,
avg_mode: str = 'median',
var_mode: str = 'confidence_weighted_var',
) -> tuple:
"""Function to fit the Ensemble Kalman Smoother for the ibl-pupil dataset.
"""Fit the Ensemble Kalman Smoother for the ibl-pupil dataset.
Args:
input_source: Directory path or list of input CSV files.
save_file: File to save outputs.
smooth_params: List containing diameter_s and com_s.
input_source: directory path or list of CSV file paths. If a directory path, all files
within this directory will be used.
save_file: File to save output dataframe.
smooth_params: [diameter param, center of mass param]
each value should be in (0, 1); closer to 1 means more smoothing
s_frames: Frames for automatic optimization if needed.
avg_mode
avg_mode: mode for averaging across ensemble
'median' | 'mean'
var_mode
'confidence_weighted_var' | 'var'
var_mode: mode for computing ensemble variance
'var' | 'confidence_weighted_var'
Returns:
tuple:
df_smotthed (pd.DataFrame):
df_smoothed (pd.DataFrame)
smooth_params (list): Final smoothing parameters used.
input_dfs_list (list): List of input DataFrames.
keypoint_names (list): List of keypoint names.
Expand All @@ -111,14 +113,13 @@ def fit_eks_pupil(
"""

# Load and format input files
input_dfs_list, output_df, keypoint_names = format_data(input_source)
input_dfs_list, _, keypoint_names = format_data(input_source)

print(f"Input data loaded for keypoints: {keypoint_names}")

# Run the ensemble Kalman smoother
df_smoothed, smooth_params, nll_values = ensemble_kalman_smoother_ibl_pupil(
df_smoothed, smooth_params_final, nll_values = ensemble_kalman_smoother_ibl_pupil(
markers_list=input_dfs_list,
keypoint_names=keypoint_names,
smooth_params=smooth_params,
s_frames=s_frames,
avg_mode=avg_mode,
Expand All @@ -130,13 +131,12 @@ def fit_eks_pupil(
df_smoothed.to_csv(save_file)
print("DataFrames successfully converted to CSV")

return df_smoothed, smooth_params, input_dfs_list, keypoint_names, nll_values
return df_smoothed, smooth_params_final, input_dfs_list, keypoint_names, nll_values


def ensemble_kalman_smoother_ibl_pupil(
markers_list: list,
keypoint_names: list,
smooth_params: list,
smooth_params: Optional[list] = None,
s_frames: Optional[list] = None,
avg_mode: str = 'median',
var_mode: str = 'confidence_weighted_var',
Expand All @@ -147,7 +147,6 @@ def ensemble_kalman_smoother_ibl_pupil(
Args:
markers_list: pd.DataFrames
each list element is a dataframe of predictions from one ensemble member
keypoint_names
smooth_params: contains smoothing parameters for diameter and center of mass
s_frames: frames for automatic optimization if s is not provided
avg_mode
Expand All @@ -165,12 +164,13 @@ def ensemble_kalman_smoother_ibl_pupil(
"""

# compute ensemble median
keys = [
'pupil_top_r_x', 'pupil_top_r_y', 'pupil_bottom_r_x', 'pupil_bottom_r_y',
'pupil_right_r_x', 'pupil_right_r_y', 'pupil_left_r_x', 'pupil_left_r_y',
]
ensemble_preds, ensemble_vars, ensemble_likes, ensemble_stacks = ensemble(
# pupil smoother only works for a pre-specified set of points
# NOTE: this order MUST be kept
keypoint_names = ['pupil_top_r', 'pupil_bottom_r', 'pupil_right_r', 'pupil_left_r']
keys = [f'{kp}_{coord}' for kp in keypoint_names for coord in ['x', 'y']]

# compute ensemble information
ensemble_preds, ensemble_vars, ensemble_likes, _ = ensemble(
markers_list, keys, avg_mode=avg_mode, var_mode=var_mode,
)

Expand Down Expand Up @@ -202,25 +202,18 @@ def ensemble_kalman_smoother_ibl_pupil(
[0, 1, 0], [.5, 0, 1],
[.5, 1, 0], [0, 0, 1],
[-.5, 1, 0], [0, 0, 1]
])
])

# placeholder diagonal matrix for ensemble variance
R = np.eye(8)

scaled_ensemble_preds = ensemble_preds.copy()
scaled_ensemble_stacks = ensemble_stacks.copy()
# subtract COM means from the ensemble predictions
for i in range(ensemble_preds.shape[1]):
if i % 2 == 0:
scaled_ensemble_preds[:, i] -= mean_x_obs
else:
scaled_ensemble_preds[:, i] -= mean_y_obs
# subtract COM means from all the predictions
for i in range(ensemble_preds.shape[1]):
if i % 2 == 0:
scaled_ensemble_stacks[:, :, i] -= mean_x_obs
else:
scaled_ensemble_stacks[:, :, i] -= mean_y_obs
y_obs = scaled_ensemble_preds

# --------------------------------------
Expand Down Expand Up @@ -304,12 +297,12 @@ def ensemble_kalman_smoother_ibl_pupil(

def pupil_optimize_smooth(
y, m0, S0, C, R, ensemble_vars, diameters_var, x_var, y_var,
s_frames=[(1, 2000)],
smooth_params=[None, None],
s_frames: Optional[list] = [(1, 2000)],
smooth_params: Optional[list] = [None, None],
):
"""Optimize-and-smooth function for the pupil example script."""
# Optimize smooth_param
if smooth_params[0] is None or smooth_params[1] is None:
if smooth_params is None or smooth_params[0] is None or smooth_params[1] is None:

# Unpack s_frames
y_shortened = crop_frames(y, s_frames)
Expand Down
20 changes: 14 additions & 6 deletions eks/multicam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
import pandas as pd
from scipy.optimize import minimize

from eks.core import ensemble, eks_zscore, compute_initial_guesses, forward_pass, backward_pass, \
compute_nll
from eks.ibl_paw_multiview_smoother import remove_camera_means, pca
from eks.utils import make_dlc_pandas_index, crop_frames
from eks.core import (
backward_pass,
compute_initial_guesses,
compute_nll,
eks_zscore,
ensemble,
forward_pass,
)
from eks.ibl_paw_multiview_smoother import pca, remove_camera_means
from eks.utils import crop_frames, make_dlc_pandas_index


def ensemble_kalman_smoother_multicam(
Expand Down Expand Up @@ -158,8 +164,10 @@ def ensemble_kalman_smoother_multicam(
# --------------------------------------
# final cleanup
# --------------------------------------
pdindex = make_dlc_pandas_index([keypoint_ensemble],
labels=["x", "y", "likelihood", "x_var", "y_var", "zscore", "nll", "ensemble_std"])
pdindex = make_dlc_pandas_index(
[keypoint_ensemble],
labels=["x", "y", "likelihood", "x_var", "y_var", "zscore", "nll", "ensemble_std"]
)
camera_indices = []
for camera in range(num_cameras):
camera_indices.append([camera * 2, camera * 2 + 1])
Expand Down
Loading

0 comments on commit a9a14c5

Please sign in to comment.