Skip to content

Commit

Permalink
updating orientation distribution plots
Browse files Browse the repository at this point in the history
  • Loading branch information
MJoosten committed Mar 14, 2024
1 parent 0b588f7 commit 38ed731
Show file tree
Hide file tree
Showing 10 changed files with 1,353 additions and 1,337 deletions.
49 changes: 27 additions & 22 deletions src/roodmus/analysis/plot_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,10 @@ def true_pose_distribution_plot(
vmax: float | None = None,
):
df_truth["euler_phi"] = df_truth["euler_phi"].astype(float)
df_truth["euler_theta"] = -(
df_truth["euler_theta"].astype(float) - np.pi / 2
)
df_truth["euler_theta"] = df_truth["euler_theta"].astype(float)
# df_truth["euler_theta"] = -(
# df_truth["euler_theta"].astype(float) - np.pi / 2
# )
df_truth["euler_psi"] = df_truth["euler_psi"].astype(float)

grid = sns.jointplot(
Expand Down Expand Up @@ -194,7 +195,7 @@ def true_pose_distribution_plot(
)
grid.ax_joint.set_xticklabels(
[
"\u03C0",
"-\u03C0",
"-3/4\u03C0",
"-\u03C0/2",
"-\u03C0/4",
Expand All @@ -205,12 +206,12 @@ def true_pose_distribution_plot(
"\u03C0",
]
)
grid.ax_joint.set_yticks([-np.pi / 2, -np.pi / 4, 0, np.pi / 4, np.pi / 2])
grid.ax_joint.set_yticks([0, np.pi / 4, np.pi / 2, 3 / 4 * np.pi, np.pi])
grid.ax_joint.set_yticklabels(
["\u03C0/2", "\u03C0/4", "0", "\u03C0/4", "\u03C0/2"]
["0", "\u03C0/4", "\u03C0/2", "3/4\u03C0", "\u03C0"]
)
grid.ax_joint.set_xlabel("Azimuth")
grid.ax_joint.set_ylabel("Elevation")
grid.ax_joint.set_ylabel("Tilt")
# add new sublot to the right of the jointplot
cbar_ax = grid.fig.add_axes([1, 0.15, 0.02, 0.7])
# add colorbar to the new subplot
Expand Down Expand Up @@ -275,7 +276,11 @@ def make_and_save_plots(
"metadata_filename"
].unique()
):
self.grid, self.vmin, self.vmax = picked_pose_distribution(
(
self.grid,
self.vmin,
self.vmax,
) = picked_pose_distribution_plot(
self.plot_data["plot_picked_pose_distribution"][
"df_picked"
],
Expand Down Expand Up @@ -310,7 +315,7 @@ def _save_plot(self, meta_file: str):
)


def picked_pose_distribution(
def picked_pose_distribution_plot(
df_picked: pd.DataFrame,
metadata_filename: str | List[str],
vmin: float | None = None,
Expand All @@ -324,15 +329,15 @@ def picked_pose_distribution(
)

# change data type of column euler_phi to float
df_picked_grouped["euler_phi"] = df_picked_grouped["euler_phi"].astype(
float
)
df_picked_grouped["euler_theta"] = -(
df_picked_grouped["euler_theta"].astype(float) - np.pi / 2
)
df_picked_grouped["euler_psi"] = df_picked_grouped["euler_psi"].astype(
float
)
# df_picked_grouped["euler_phi"] = df_picked_grouped["euler_phi"].astype(
# float
# )
# df_picked_grouped["euler_theta"] = -(
# df_picked_grouped["euler_theta"].astype(float) - np.pi / 2
# )
# df_picked_grouped["euler_psi"] = df_picked_grouped["euler_psi"].astype(
# float
# )

# plot the alignment
grid = sns.jointplot(
Expand Down Expand Up @@ -363,7 +368,7 @@ def picked_pose_distribution(
)
grid.ax_joint.set_xticklabels(
[
"\u03C0",
"-\u03C0",
"-3/4\u03C0",
"-\u03C0/2",
"-\u03C0/4",
Expand All @@ -374,12 +379,12 @@ def picked_pose_distribution(
"\u03C0",
]
)
grid.ax_joint.set_yticks([-np.pi / 2, -np.pi / 4, 0, np.pi / 4, np.pi / 2])
grid.ax_joint.set_yticks([0, np.pi / 4, np.pi / 2, 3 / 4 * np.pi, np.pi])
grid.ax_joint.set_yticklabels(
["-\u03C0/2", "-\u03C0/4", "0", "\u03C0/4", "\u03C0/2"]
["0", "\u03C0/4", "\u03C0/2", "3/4\u03C0", "\u03C0"]
)
grid.ax_joint.set_xlabel("Azimuth")
grid.ax_joint.set_ylabel("Elevation")
grid.ax_joint.set_ylabel("Tilt")
# add new sublot to the right of the jointplot
cbar_ax = grid.fig.add_axes([1, 0.15, 0.02, 0.7])
# add colorbar to the new subplot
Expand Down
12 changes: 6 additions & 6 deletions src/roodmus/analysis/plot_picking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2215,13 +2215,13 @@ def plot_precision(
sm = plt.cm.ScalarMappable(
cmap="RdYlBu",
norm=plt.Normalize(
vmin=df_precision["defocus"].min(),
vmax=df_precision["defocus"].max(),
vmin=df_precision["defocus"].min() / 10000,
vmax=df_precision["defocus"].max() / 10000,
),
)
sm._A = []
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("defocus (Å)", rotation=270, labelpad=20, fontsize=12)
cbar.set_label("defocus (\u03bcm)", rotation=270, labelpad=20, fontsize=12)
# add labels
ax.set_xlabel("")
ax.set_ylabel("precision", fontsize=14)
Expand Down Expand Up @@ -2283,13 +2283,13 @@ def plot_recall(
sm = plt.cm.ScalarMappable(
cmap="RdYlBu",
norm=plt.Normalize(
vmin=df_precision["defocus"].min(),
vmax=df_precision["defocus"].max(),
vmin=df_precision["defocus"].min() / 10000,
vmax=df_precision["defocus"].max() / 10000,
),
)
sm._A = []
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("defocus (Å)", rotation=270, labelpad=20, fontsize=12)
cbar.set_label("defocus (\u03bcm)", rotation=270, labelpad=20, fontsize=12)
# add labels
ax.set_xlabel("")
ax.set_ylabel("recall", fontsize=14)
Expand Down
57 changes: 34 additions & 23 deletions src/roodmus/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import yaml
import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm
import pandas as pd
import pickle
Expand Down Expand Up @@ -59,26 +60,25 @@ def load_cs(self, cs_path):

@classmethod
def get_ugraph_cs(self, metadata_cs: np.recarray) -> List[str]:
"""Grab micrograph file paths from .cs data.
"""Grab micrograph file paths from .cs data. if micrograph file
paths are not present in the metadata, return an empty list.
Args:
metadata_cs (np.recarray): .cs file metadata.
Returns:
ugraph_paths (List[str]): micrograph file paths.
"""
ugraph_paths = metadata_cs["location/micrograph_path"].tolist()

# elif "blob/path" in metadata_cs.dtype.names:
# ugraph_paths = metadata_cs["blob/path"]
# removed None from within the function to stop mypy errors
# else:
# return None
if "location/micrograph_path" in metadata_cs.dtype.names:
ugraph_paths = metadata_cs["location/micrograph_path"].tolist()
ugraph_paths = [
os.path.basename(path).decode("utf-8").split("_")[-1]
for path in ugraph_paths
]
else:
ugraph_paths = []

ugraph_paths = [
os.path.basename(path).decode("utf-8").split("_")[-1]
for path in ugraph_paths
]
return ugraph_paths

@classmethod
Expand Down Expand Up @@ -164,13 +164,13 @@ def get_orientations_cs(self, metadata_cs: np.recarray, return_pose=False):
"alignments3D/pose"
] # orientations as rodriques vectors
# convert to euler angles
euler = np.array(
[geom.rot2euler(geom.expmap(p)) for p in pose],
dtype=float,
)
# euler = R.from_rotvec(pose).as_euler(
# "zyx", degrees=False
# ) # convert to euler angles
# euler = np.array(
# [geom.rot2euler(geom.expmap(p)) for p in pose],
# dtype=float,
# )
euler = R.from_rotvec(pose).as_euler(
"ZYZ", degrees=False
) # convert to euler angles
if return_pose:
return euler, pose
else:
Expand Down Expand Up @@ -205,7 +205,7 @@ def get_class2D_cs(self, metadata_cs: np.recarray):
np.ndarray: 2D class.
"""
if "alignments2D/class" in metadata_cs.dtype.names:
class2d = metadata_cs["alignments2D/class"]
class2d = metadata_cs["alignments2D/class"].astype(int)
else:
class2d = None
return class2d
Expand Down Expand Up @@ -359,10 +359,16 @@ def get_orientations_star(self, metadata_star) -> np.ndarray:
)
if not euler_phi: # if empty
euler_phi = [np.nan] * num_particles
else:
euler_phi = [np.deg2rad(float(r)) for r in euler_phi]
if not euler_theta:
euler_theta = [np.nan] * num_particles
else:
euler_theta = [np.deg2rad(float(r)) for r in euler_theta]
if not euler_psi:
euler_psi = [np.nan] * num_particles
else:
euler_psi = [np.deg2rad(float(r)) for r in euler_psi]
euler = np.stack([euler_phi, euler_theta, euler_psi], axis=1)
return euler

Expand All @@ -376,7 +382,12 @@ def get_class2D_star(self, metadata_star):
Returns:
class2d (np.ndarray): 2D class data.
"""
class2d = metadata_star.column_as_list("particles", "_rlnClassNumber")
class2d = [
int(r)
for r in metadata_star.column_as_list(
"particles", "_rlnClassNumber"
)
]
if class2d:
return np.array(class2d)
else:
Expand Down Expand Up @@ -1017,7 +1028,7 @@ def _extract_from_metadata(
ugraph_filename = IO.get_ugraph_cs(m)
uid = IO.get_uid_cs(m)
tmp_results["uid"].extend(uid)
if ugraph_filename is None:
if not ugraph_filename:
num_particles = len(uid)
tmp_results["ugraph_filename"].extend(
[np.nan] * num_particles
Expand Down Expand Up @@ -1402,8 +1413,8 @@ def _extract_from_config(
position = instance["position"]
orientation = instance["orientation"] # rotation vector
# convert to euler angles
euler = geom.rot2euler(geom.expmap(np.array(orientation)))
# euler = R.from_rotvec(orientation).as_euler("zyx")
# euler = geom.rot2euler(geom.expmap(np.array(orientation)))
euler = R.from_rotvec(orientation).as_euler("ZYZ")
if return_pose:
orientations_list.append(orientation)
else:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 38ed731

Please sign in to comment.