Skip to content

Commit

Permalink
fixed a bug in the load_data class during 1-to-1 particle matching
Browse files Browse the repository at this point in the history
  • Loading branch information
MJoosten committed Jul 9, 2024
1 parent 3f77a5d commit 37507a5
Showing 1 changed file with 51 additions and 15 deletions.
66 changes: 51 additions & 15 deletions src/roodmus/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,40 @@ def get_ugraph_cs(self, metadata_cs: np.recarray) -> List[str]:

if "location/micrograph_path" in metadata_cs.dtype.names:
ugraph_paths = metadata_cs["location/micrograph_path"].tolist()
ugraph_paths = [
os.path.basename(path).decode("utf-8").split("_")[-1]
for path in ugraph_paths
# Cryosparc may add suffixes to the micrograph
# name, such as _patch_aligned_doseweighted.mrc.
# We remove these suffixes to match the micrograph
# names in the config files.
list_of_suffixes = [
"_patch_aligned_doseweighted",
]
ugraph_path_basenames = []
for ugraph_path in ugraph_paths:
ugraph_path_basename = os.path.basename(ugraph_path).decode(
"utf-8"
)
for suffix in list_of_suffixes:
ugraph_path_basename = ugraph_path_basename.replace(
suffix, ""
)
# check that the final part of the basename
# can be converted to an integer
if ugraph_path_basename.split("_")[-1].isdigit():
ugraph_path_basenames.append(
ugraph_path_basename.split("_")[-1]
)
elif ugraph_path_basename.split("_")[-2].isdigit():
# there is some additional suffix which we want to preserve
ugraph_path_basenames.append(
"_".join(ugraph_path_basename.split("_")[-2:])
)
else:
ugraph_path_basenames.append(ugraph_path_basename)
# ugraph_path_basenames.append(ugraph_path_basename.split("_")[-1])
else:
ugraph_paths = []
ugraph_path_basenames = []

return ugraph_paths
return ugraph_path_basenames

@classmethod
def get_uid_cs(self, metadata_cs: np.recarray):
Expand Down Expand Up @@ -1579,26 +1605,33 @@ def _match_particles(
matched_particles, unmatched_picked_particles,
unmatched_truth_particles
"""
# to hold list of dfs (1 entry per ugraph) before concat
matched_picked_dfs = []
matched_truth_dfs = []
unmatched_picked_dfs = []
unmatched_truth_dfs = []
# list to hold all indices of matched truth particles
# for the total number of picked particles
truth_index_for_picked = []

# loop over the metadata files
if not metadata_filenames:
metadata_filenames = results_picking["metadata_filename"].unique()
elif isinstance(metadata_filenames, str):
metadata_filenames = [metadata_filenames]
for midx, metadata_filename in enumerate(metadata_filenames):
# to hold list of dfs (1 entry per ugraph) before concat
matched_picked_dfs = []
matched_truth_dfs = []
unmatched_picked_dfs = []
unmatched_truth_dfs = []
# list to hold all indices of matched truth particles
# for the total number of picked particles
truth_index_for_picked = []

# grab particles for this metadata file only
metafile_picked = results_picking.loc[
results_picking["metadata_filename"] == metadata_filename
]
metafile_truth = results_truth
if verbose:
print(
"number of picked particles for "
+ f"{metadata_filename}: {len(metafile_picked)}"
)
print(f"number of truth particles: {len(metafile_truth)}")

# now find unique ugraphs, loop over whilst computing matches
# and unmatched particles
Expand Down Expand Up @@ -2219,6 +2252,7 @@ def compute_1to1_match_precision(
t_unmatched: pd.DataFrame,
results_truth: pd.DataFrame,
verbose: bool = False,
enable_tqdm: bool = False,
):
"""This function produces another data frame containing the number
of true positives, false positives, false negatives
Expand Down Expand Up @@ -2286,13 +2320,15 @@ def compute_1to1_match_precision(
)

progressbar = tqdm(
total=len(df_truth_grouped.groups.keys()),
total=len(p_match_grouped.groups.keys()),
desc="computing precision",
disable=False, # not verbose,
disable=not enable_tqdm,
)

groupnames: dict_keys = p_match_grouped.groups.keys()
for groupname in groupnames:
if verbose and not enable_tqdm:
print(groupname)
# grab the particles in this ugraph
p_match_in_ugraph = p_match_grouped.get_group(groupname)
TP = len(p_match_in_ugraph)
Expand Down

0 comments on commit 37507a5

Please sign in to comment.