Skip to content

Commit

Permalink
updated plotting and fixed bug in precision calc
Browse files Browse the repository at this point in the history
  • Loading branch information
MJoosten committed Apr 19, 2024
1 parent 00a2060 commit b8c5bb3
Show file tree
Hide file tree
Showing 2 changed files with 434 additions and 60,441 deletions.
60,823 changes: 401 additions & 60,422 deletions paper/figure_2.ipynb

Large diffs are not rendered by default.

52 changes: 33 additions & 19 deletions src/roodmus/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,7 @@ def _match_particles(
results_picking: pd.DataFrame,
results_truth: pd.DataFrame,
verbose: bool = False,
enable_tqdm: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame,]:
"""When picked and truth dfs are loaded, this can be used to create
dataframes of matched picked particles, matched truth particles,
Expand Down Expand Up @@ -1596,7 +1597,7 @@ def _match_particles(
progressbar = tqdm(
total=len(ugraph_ids),
desc="computing closest matches",
disable=not verbose,
disable=not enable_tqdm,
)
for ugraph in ugraph_ids:
# grab the positions from picked+truth to match
Expand Down Expand Up @@ -1642,12 +1643,14 @@ def _match_particles(
np.nanargmin(picked_particle)
)
else:
print(
"all nan slice in ugraph {} with picked particle"
" index {} with entries: {}".format(
ugraph, j, picked_particle
if verbose:
print(
"all nan slice in ugraph {}"
" with picked particle"
" index {} with entries: {}".format(
ugraph, j, picked_particle
)
)
)
# no particle is matched to picked particle
# and we move on to next particle
continue
Expand Down Expand Up @@ -1686,13 +1689,14 @@ def _match_particles(
> 1
]
)
print(
"There are {} non-unique picked"
" particles in ugraph {}!".format(
non_unique_count,
ugraph,
if verbose:
print(
"There are {} non-unique picked"
" particles in ugraph {}!".format(
non_unique_count,
ugraph,
)
)
)
if non_unique_count > 0:
print(
"This may cause problems with overwritten assns"
Expand All @@ -1714,13 +1718,13 @@ def _match_particles(
== 0
]
)
if no_truth_match > 0:
if no_truth_match > 0 and verbose:
print(
"There are {} no-truth-match picked particles!".format(
no_truth_match
)
)
if non_unique_count > 0:
if non_unique_count > 0 and verbose:
print(
"This may cause problems with overwritten assns"
" in truth particles dict!"
Expand All @@ -1743,7 +1747,7 @@ def _match_particles(
unmatched_picked_dfs.append(ugraph_picked.iloc[p_unmatched])

# Extract the unmatched truth particles
t_list = np.arange(len(picked_pos_x), dtype=int).tolist()
t_list = np.arange(len(truth_pos_x), dtype=int).tolist()
t_unmatched = list(set(t_list).difference(t_match))
unmatched_truth_dfs.append(ugraph_truth.iloc[t_unmatched])

Expand Down Expand Up @@ -2213,17 +2217,27 @@ def compute_1to1_match_precision(
p_match_in_ugraph = p_match_grouped.get_group(groupname)
TP = len(p_match_in_ugraph)

p_unmatched_in_ugraph = p_unmatched_grouped.get_group(groupname)
FP = len(p_unmatched_in_ugraph)
if groupname in p_unmatched_grouped.groups.keys():
p_unmatched_in_ugraph = p_unmatched_grouped.get_group(
groupname
)
FP = len(p_unmatched_in_ugraph)
else:
FP = 0

"""
t_match_in_ugraph = t_match_grouped.get_group(
groupname[1]
)
"""

t_unmatched_in_ugraph = t_unmatched_grouped.get_group(groupname[1])
FN = len(t_unmatched_in_ugraph)
if groupname[1] in t_unmatched_grouped.groups.keys():
t_unmatched_in_ugraph = t_unmatched_grouped.get_group(
groupname[1]
)
FN = len(t_unmatched_in_ugraph)
else:
FN = 0

truth_particles_in_ugraph = df_truth_grouped.get_group(
groupname[1]
Expand Down

0 comments on commit b8c5bb3

Please sign in to comment.