Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quick fix to precision recall #63

Merged
merged 6 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Currently, six metric operations are supported:
3. mp-value
4. Grit
5. Enrichment
6. Hit@k

## Demos

Expand Down
1 change: 1 addition & 0 deletions cytominer_eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def evaluate(
metric_result = precision_recall(
similarity_melted_df=similarity_melted_df,
replicate_groups=replicate_groups,
groupby_columns=groupby_columns,
k=precision_recall_k,
)
elif operation == "grit":
Expand Down
22 changes: 14 additions & 8 deletions cytominer_eval/operations/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
def precision_recall(
similarity_melted_df: pd.DataFrame,
replicate_groups: List[str],
groupby_columns: List[str],
k: Union[int, List[int]],
) -> pd.DataFrame:
"""Determine the precision and recall at k for all unique replicate groups
"""Determine the precision and recall at k for all unique groupby_columns samples
based on a predefined similarity metric (see cytominer_eval.transform.metric_melt)

Parameters
Expand All @@ -26,15 +27,20 @@ def precision_recall(
samples. Importantly, it must follow the exact structure as output from
:py:func:`cytominer_eval.transform.transform.metric_melt`.
replicate_groups : List
a list of metadata column names in the original profile dataframe to use as
replicate columns.
a list of metadata column names in the original profile dataframe to use as replicate columns.
groupby_columns : List of str
Columns by which the similarity matrix is grouped and for which the precision/recall is calculated.
For example, if groupby_columns = ["Metadata_sample"], then the precision is calculated for each sample.
Calculating the precision by sample is the default,
but it is also mathematically valid to calculate the precision at the MOA level;
the result is just less intuitive to interpret.
k : List of ints or int
an integer, or a list of integers, indicating how many pairwise comparisons to threshold.

Returns
-------
pandas.DataFrame
precision and recall metrics for all replicate groups given k
precision and recall metrics for all groupby_columns groups given k
"""
# Determine pairwise replicates and make sure to sort based on the metric!
similarity_melted_df = assign_replicates(
Expand All @@ -46,9 +52,9 @@ def precision_recall(

# Extract out specific columns
pair_ids = set_pair_ids()
replicate_group_cols = [
groupby_cols_suffix = [
"{x}{suf}".format(x=x, suf=pair_ids[list(pair_ids)[0]]["suffix"])
for x in replicate_groups
for x in groupby_columns
]
# iterate over all k
precision_recall_df = pd.DataFrame()
Expand All @@ -57,11 +63,11 @@ def precision_recall(
for k_ in k:
# Calculate precision and recall for all groups
precision_recall_df_at_k = similarity_melted_df.groupby(
replicate_group_cols
groupby_cols_suffix
).apply(lambda x: calculate_precision_recall(x, k=k_))
precision_recall_df = precision_recall_df.append(precision_recall_df_at_k)

# Rename the suffixed columns back to the groupby columns provided
rename_cols = dict(zip(replicate_group_cols, replicate_groups))
rename_cols = dict(zip(groupby_cols_suffix, groupby_columns))

return precision_recall_df.reset_index().rename(rename_cols, axis="columns")
5 changes: 5 additions & 0 deletions cytominer_eval/tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def test_evaluate_precision_recall():
},
}

gene_groupby_columns = ["Metadata_pert_name"]
compound_groupby_columns = ["Metadata_broad_sample"]

for k in ks:

# first test the function with k = float, later we test with k = list of floats
Expand All @@ -140,6 +143,7 @@ def test_evaluate_precision_recall():
features=gene_features,
meta_features=gene_meta_features,
replicate_groups=gene_groups,
groupby_columns=gene_groupby_columns,
operation="precision_recall",
similarity_metric="pearson",
precision_recall_k=k,
Expand All @@ -159,6 +163,7 @@ def test_evaluate_precision_recall():
features=compound_features,
meta_features=compound_meta_features,
replicate_groups=["Metadata_broad_sample"],
groupby_columns=compound_groupby_columns,
operation="precision_recall",
similarity_metric="pearson",
precision_recall_k=[k],
Expand Down
19 changes: 10 additions & 9 deletions cytominer_eval/tests/test_operations/test_precision_recall.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import os
import random
import pytest
import pathlib
import tempfile
import numpy as np
import pandas as pd


from cytominer_eval.transform import metric_melt
from cytominer_eval.operations import precision_recall

random.seed(123)
tmpdir = tempfile.gettempdir()
random.seed(42)

# Load CRISPR dataset
example_file = "SQ00014610_normalized_feature_select.csv.gz"
Expand All @@ -37,32 +34,36 @@

replicate_groups = ["Metadata_gene_name", "Metadata_cell_line"]

groupby_columns = ["Metadata_pert_name"]


def test_precision_recall():
result_list = precision_recall(
similarity_melted_df=similarity_melted_df,
replicate_groups=replicate_groups,
groupby_columns=groupby_columns,
k=[5, 10],
)

result_int = precision_recall(
similarity_melted_df=similarity_melted_df,
replicate_groups=replicate_groups,
groupby_columns=groupby_columns,
k=5,
)

assert len(result_list.k.unique()) == 2
assert result_list.k.unique()[0] == 5

# ITGAV has a really strong profile
# ITGAV-1 has a really strong profile
assert (
result_list.sort_values(by="recall", ascending=False)
.reset_index(drop=True)
.iloc[0, :]
.Metadata_gene_name
== "ITGAV"
.Metadata_pert_name
== "ITGAV-1"
)

assert all(x in result_list.columns for x in replicate_groups)
assert all(x in result_list.columns for x in groupby_columns)

assert result_int.equals(result_list.query("k == 5"))