Skip to content

Commit

Permalink
fixes to plotting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
MJoosten committed Jun 27, 2024
1 parent 37a9685 commit 11b542b
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 16 deletions.
26 changes: 17 additions & 9 deletions src/roodmus/analysis/plot_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def true_pose_distribution_plot(
cmap="RdYlBu_r",
marginal_kws=dict(bins=100, fill=False),
)
grid.fig.set_size_inches(14, 7)
grid.figure.set_size_inches(7, 3.5)
# adjust the x and y ticks to show multiples of pi
grid.ax_joint.set_xticks(
[
Expand All @@ -204,7 +204,8 @@ def true_pose_distribution_plot(
"\u03C0/2",
"3/4\u03C0",
"\u03C0",
]
],
rotation=45,
)
grid.ax_joint.set_yticks([0, np.pi / 4, np.pi / 2, 3 / 4 * np.pi, np.pi])
grid.ax_joint.set_yticklabels(
Expand All @@ -215,7 +216,9 @@ def true_pose_distribution_plot(
# add new sublot to the right of the jointplot
cbar_ax = grid.fig.add_axes([1, 0.15, 0.02, 0.7])
# add colorbar to the new subplot
grid.fig.colorbar(grid.ax_joint.collections[0], cax=cbar_ax, label="count")
grid.figure.colorbar(
grid.ax_joint.collections[0], cax=cbar_ax, label="Count"
)
if vmin and vmax:
# set limits of the colorbar to the same as for
# the picked particles
Expand All @@ -224,7 +227,7 @@ def true_pose_distribution_plot(
# get the limits of the colorbar
vmin, vmax = grid.ax_joint.collections[0].get_clim()
# add title to the top of the jointplot
grid.fig.suptitle("true particle pose distribution", fontsize=20, y=1.05)
grid.fig.suptitle("Ground-Truth pose distribution", fontsize=21, y=1.05)

return grid, vmin, vmax

Expand Down Expand Up @@ -351,7 +354,7 @@ def picked_pose_distribution_plot(
cmap="RdYlBu_r",
marginal_kws=dict(bins=100, fill=False),
)
grid.fig.set_size_inches(14, 7)
grid.fig.set_size_inches(7, 3.5)
# adjust the x and y ticks to show multiples of pi
grid.ax_joint.set_xticks(
[
Expand All @@ -377,7 +380,8 @@ def picked_pose_distribution_plot(
"\u03C0/2",
"3/4\u03C0",
"\u03C0",
]
],
rotation=45,
)
grid.ax_joint.set_yticks([0, np.pi / 4, np.pi / 2, 3 / 4 * np.pi, np.pi])
grid.ax_joint.set_yticklabels(
Expand All @@ -386,9 +390,11 @@ def picked_pose_distribution_plot(
grid.ax_joint.set_xlabel("Azimuth")
grid.ax_joint.set_ylabel("Tilt")
# add new sublot to the right of the jointplot
cbar_ax = grid.fig.add_axes([1, 0.15, 0.02, 0.7])
cbar_ax = grid.figure.add_axes([1, 0.15, 0.02, 0.7])
# add colorbar to the new subplot
grid.fig.colorbar(grid.ax_joint.collections[0], cax=cbar_ax, label="count")
grid.figure.colorbar(
grid.ax_joint.collections[0], cax=cbar_ax, label="Count"
)
if vmin and vmax:
# set limits of the colorbar to the
# same as for the picked particles
Expand All @@ -397,7 +403,9 @@ def picked_pose_distribution_plot(
# get the limits of the colorbar
vmin, vmax = grid.ax_joint.collections[0].get_clim()
# add title to the top of the jointplot
grid.fig.suptitle("picked particle pose distribution", fontsize=20, y=1.05)
grid.figure.suptitle(
"Picked particle pose distribution", fontsize=21, y=1.05
)

return grid, vmin, vmax

Expand Down
6 changes: 3 additions & 3 deletions src/roodmus/analysis/plot_ctf.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,8 @@ def plot_defocus_scatter(
ax[0].set_xlabel("defocus truth [$\u212B$]")
ax[1].set_xlabel("defocus truth [$\u212B$]")
ax[0].set_ylabel("defocusU estimated [$\u212B$]")
ax[0].set_title("defocusU")
ax[1].set_title("defocusV")
ax[0].set_title("DefocusU (\u03bcm)")
ax[1].set_title("DefocusV (\u03bcm)")
ax[0].grid(False)
ax[1].grid(False)
# add colorbar legend
Expand All @@ -421,7 +421,7 @@ def plot_defocus_scatter(
)
sm._A = []
cbar = plt.colorbar(sm)
cbar.set_label("micrograph")
cbar.set_label("Micrograph")
fig.tight_layout()
return fig, ax

Expand Down
7 changes: 5 additions & 2 deletions src/roodmus/analysis/plot_picking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2221,7 +2221,9 @@ def plot_precision(
)
sm._A = []
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("defocus (\u03bcm)", rotation=270, labelpad=20, fontsize=12)
cbar.set_label("Defocus (\u03bcm)", rotation=270, labelpad=20, fontsize=21)
# change tick labels of the colourbar to use 18 fontsize
cbar.ax.tick_params(labelsize=18)
# add labels
ax.set_xlabel("")
ax.set_ylabel("precision", fontsize=14)
Expand Down Expand Up @@ -2289,7 +2291,8 @@ def plot_recall(
)
sm._A = []
cbar = fig.colorbar(sm, ax=ax)
cbar.set_label("defocus (\u03bcm)", rotation=270, labelpad=20, fontsize=12)
cbar.set_label("Defocus (\u03bcm)", rotation=270, labelpad=20, fontsize=21)
cbar.ax.tick_params(labelsize=18)
# add labels
ax.set_xlabel("")
ax.set_ylabel("recall", fontsize=14)
Expand Down
16 changes: 14 additions & 2 deletions src/roodmus/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ def _match_particles(
results_truth: pd.DataFrame,
verbose: bool = False,
enable_tqdm: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame,]:
) -> Tuple:
"""When picked and truth dfs are loaded, this can be used to create
dataframes of matched picked particles, matched truth particles,
picked particles not matched to truth particles and finally, truth
Expand Down Expand Up @@ -1580,6 +1580,9 @@ def _match_particles(
matched_truth_dfs = []
unmatched_picked_dfs = []
unmatched_truth_dfs = []
# list to hold all indices of matched truth particles
# for the total number of picked particles
truth_index_for_picked = []

# loop over the metadata files
if isinstance(metadata_filenames, str):
Expand Down Expand Up @@ -1653,6 +1656,7 @@ def _match_particles(
)
# no particle is matched to picked particle
# and we move on to next particle
truth_index_for_picked.append(np.nan)
continue
# check if closest truth particle is within particle
# diameter of picked particle
Expand All @@ -1661,6 +1665,7 @@ def _match_particles(
> self.particle_diameter
):
closest_truth_index.append(np.nan)
truth_index_for_picked.append(np.nan)
continue
# if it is, consider picking successful and allow the
# pickedand truth particle to be associated with each
Expand All @@ -1672,6 +1677,12 @@ def _match_particles(
p_match.append(j)
t_match.append(truth_particle_index)

# to the list of truth indices for picked particles
# add the index of the truth particle in results_truth
truth_index_for_picked.append(
ugraph_truth.index[truth_particle_index]
)

# check whether any truth particles had multiple
# picked particles mapped to them
non_unique_count = len(
Expand All @@ -1697,7 +1708,7 @@ def _match_particles(
ugraph,
)
)
if non_unique_count > 0:
if non_unique_count > 0 and verbose:
print(
"This may cause problems with overwritten assns"
" in truth particles dict!"
Expand Down Expand Up @@ -1775,6 +1786,7 @@ def _match_particles(
matched_truth_df,
unmatched_picked_df,
unmatched_truth_df,
truth_index_for_picked,
)

def _calc_neighbours(self, pos_picked, pos_truth, r):
Expand Down

0 comments on commit 11b542b

Please sign in to comment.