Skip to content

Commit

Permalink
MRG: add abund estimation to manysearch (#302)
Browse files Browse the repository at this point in the history
* add abund estimates

* only calc results with nonzero overlap

* rm eprintln

* update manysearch info

---------

Co-authored-by: C. Titus Brown <[email protected]>
  • Loading branch information
bluegenes and ctb authored Jul 3, 2024
1 parent f6c7c30 commit 990790f
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 39 deletions.
4 changes: 3 additions & 1 deletion doc/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ sourmash scripts manysearch queries.zip metagenomes.manifest.csv -o results.csv
```
<!-- We suggest using a manifest CSV for the metagenome collection. -->

The results file here, `query.x.gtdb-reps.csv`, will have 8 columns: `query` and `query_md5`, `match` and `match_md5`, and `containment`, `jaccard`, `max_containment`, and `intersect_hashes`.
The results file here, `query.x.gtdb-reps.csv`, will have the following columns: `query_name`, `query_md5`, `match_name`, `match_md5`, `containment`, `jaccard`, `max_containment`, `intersect_hashes`, `query_containment_ani`.

If you run `manysearch` _without_ using a rocksdb database, the results file will also have the following columns: `average_abund`, `median_abund`, `std_abund`, `match_containment_ani`, `average_containment_ani`, `max_containment_ani`. If the query sketches were not built with abundance tracking enabled, `average_abund` and `median_abund` will default to `1.0`; `std_abund` will default to `0.0`.


### Running `cluster`
Expand Down
104 changes: 67 additions & 37 deletions src/manysearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
/// database once.
use anyhow::Result;
use rayon::prelude::*;
use stats::{median, stddev};
use std::sync::atomic;
use std::sync::atomic::AtomicUsize;

Expand Down Expand Up @@ -70,44 +71,73 @@ pub fn manysearch(
Ok(against_sig) => {
if let Some(against_mh) = against_sig.minhash() {
for query in query_sketchlist.iter() {
// to do - let user choose?
let calc_abund_stats = query.minhash.track_abundance();

let against_mh_ds = against_mh.downsample_scaled(query.minhash.scaled()).unwrap();
let overlap =
query.minhash.count_common(against_mh, true).unwrap() as f64;
let query_size = query.minhash.size() as f64;
let target_size = against_mh.size() as f64;
let containment_query_in_target = overlap / query_size;
let containment_target_in_query = overlap / target_size;
let max_containment =
containment_query_in_target.max(containment_target_in_query);
let jaccard = overlap / (target_size + query_size - overlap);

let qani = ani_from_containment(
containment_query_in_target,
against_mh.ksize() as f64,
);
let mani = ani_from_containment(
containment_target_in_query,
against_mh.ksize() as f64,
);
let query_containment_ani = Some(qani);
let match_containment_ani = Some(mani);
let average_containment_ani = Some((qani + mani) / 2.);
let max_containment_ani = Some(f64::max(qani, mani));

if containment_query_in_target > threshold {
results.push(SearchResult {
query_name: query.name.clone(),
query_md5: query.md5sum.clone(),
match_name: against_sig.name(),
containment: containment_query_in_target,
intersect_hashes: overlap as usize,
match_md5: Some(against_sig.md5sum()),
jaccard: Some(jaccard),
max_containment: Some(max_containment),
query_containment_ani,
match_containment_ani,
average_containment_ani,
max_containment_ani,
});
query.minhash.count_common(&against_mh_ds, false).unwrap() as f64;

// only calculate results if we have shared hashes
if overlap > 0.0 {
let query_size = query.minhash.size() as f64;
let target_size = against_mh.size() as f64;
let containment_query_in_target = overlap / query_size;
let containment_target_in_query = overlap / target_size;

let max_containment =
containment_query_in_target.max(containment_target_in_query);
let jaccard = overlap / (target_size + query_size - overlap);

let qani = ani_from_containment(
containment_query_in_target,
against_mh.ksize() as f64,
);
let mani = ani_from_containment(
containment_target_in_query,
against_mh.ksize() as f64,
);
let query_containment_ani = Some(qani);
let match_containment_ani = Some(mani);
let average_containment_ani = Some((qani + mani) / 2.);
let max_containment_ani = Some(f64::max(qani, mani));

let (average_abund, median_abund, std_abund) = if calc_abund_stats {
match against_mh_ds.inflated_abundances(&query.minhash) {
Ok((abunds, sum_weighted_overlap)) => {
let average_abund = sum_weighted_overlap as f64 / abunds.len() as f64;
let median_abund = median(abunds.iter().cloned()).unwrap();
let std_abund = stddev(abunds.iter().cloned());
(average_abund, median_abund, std_abund)
}
Err(e) => {
eprintln!("Error calculating abundances for query: {}, against: {}; Error: {}", query.name, against_sig.name(), e);
continue;
}
}
} else {
(1.0, 1.0, 0.0)
};

if containment_query_in_target > threshold {
results.push(SearchResult {
query_name: query.name.clone(),
query_md5: query.md5sum.clone(),
match_name: against_sig.name(),
containment: containment_query_in_target,
intersect_hashes: overlap as usize,
match_md5: Some(against_sig.md5sum()),
jaccard: Some(jaccard),
max_containment: Some(max_containment),
average_abund,
median_abund,
std_abund,
query_containment_ani,
match_containment_ani,
average_containment_ani,
max_containment_ani,
});
}
}
}
} else {
Expand Down
4 changes: 4 additions & 0 deletions src/mastiff_manysearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ pub fn mastiff_manysearch(
match_md5: None,
jaccard: None,
max_containment: None,
// can't calculate from here -- need to get these from w/in sourmash
average_abund: 1.0,
median_abund: 1.0,
std_abund: 0.0,
query_containment_ani,
match_containment_ani: None,
average_containment_ani: None,
Expand Down
87 changes: 86 additions & 1 deletion src/python/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def test_simple(runtmp, zip_query, zip_against):
assert float(row['match_containment_ani'] == 1.0)
assert float(row['average_containment_ani'] == 1.0)
assert float(row['max_containment_ani'] == 1.0)
assert float(row['average_abund'] == 1.0)
assert float(row['median_abund'] == 1.0)
assert float(row['std_abund'] == 0.0)

else:
# confirm hand-checked numbers
Expand All @@ -90,6 +93,9 @@ def test_simple(runtmp, zip_query, zip_against):
match_ani = float(row['match_containment_ani'])
average_ani = float(row['average_containment_ani'])
max_ani = float(row['max_containment_ani'])
average_abund = float(row['average_abund'])
median_abund = float(row['median_abund'])
std_abund = float(row['std_abund'])

jaccard = round(jaccard, 4)
cont = round(cont, 4)
Expand All @@ -98,7 +104,12 @@ def test_simple(runtmp, zip_query, zip_against):
match_ani = round(match_ani, 4)
average_ani = round(average_ani, 4)
max_ani = round(max_ani, 4)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}", f"{query_ani:.04}", f"{match_ani:.04}", f"{average_ani:.04}", f"{max_ani:.04}")
avg_abund = round(average_abund, 4)
med_abund = round(median_abund, 4)
std_abund = round(std_abund, 4)
print(q, m, f"{jaccard:.04}", f"{cont:.04}", f"{maxcont:.04}",
f"{query_ani:.04}", f"{match_ani:.04}", f"{average_ani:.04}", f"{max_ani:.04}",
f"{avg_abund:.04}", f"{med_abund:.04}", f"{std_abund:.04}")

if q == 'NC_011665.1' and m == 'NC_009661.1':
assert jaccard == 0.3207
Expand All @@ -109,6 +120,9 @@ def test_simple(runtmp, zip_query, zip_against):
assert match_ani == 0.9772
assert average_ani == 0.977
assert max_ani == 0.9772
assert avg_abund == 1.0
assert med_abund == 1.0
assert std_abund == 0.0

if q == 'NC_009661.1' and m == 'NC_011665.1':
assert jaccard == 0.3207
Expand All @@ -119,6 +133,77 @@ def test_simple(runtmp, zip_query, zip_against):
assert match_ani == 0.9768
assert average_ani == 0.977
assert max_ani == 0.9772
assert avg_abund == 1.0
assert med_abund == 1.0
assert std_abund == 0.0


def test_simple_abund(runtmp):
    """Run `manysearch` with an abundance signature as the query.

    Searches the query against a file list of three signatures at k=31,
    scaled=100000, expects exactly one result row, and compares that row
    (containment, ANI and abundance statistics) against hand-checked numbers.
    """
    abund_query = get_test_data('SRR606249.sig.gz')
    against_list = runtmp.output('against.txt')

    against_sigs = [
        get_test_data(filename)
        for filename in ('2.fa.sig.gz', '47.fa.sig.gz', '63.fa.sig.gz')
    ]
    make_file_list(against_list, against_sigs)

    output = runtmp.output('out.csv')

    runtmp.sourmash('scripts', 'manysearch', abund_query, against_list,
                    '-o', output, '--scaled', '100000', '-k', '31')

    assert os.path.exists(output)

    results = pandas.read_csv(output)
    assert len(results) == 1

    rows_by_index = results.to_dict(orient='index')
    print(rows_by_index)

    def rounded(record, column):
        # Parse the CSV field as a float and keep 4 decimal places so it can
        # be compared against the hand-checked constants below.
        return round(float(record[column]), 4)

    for row in rows_by_index.values():
        # Names are compared on their first whitespace-separated token only.
        query_id = row['query_name'].split()[0]
        assert query_id == "SRR606249"
        match_id = row['match_name'].split()[0]
        assert "NC_011665.1" in match_id

        cont = rounded(row, 'containment')
        jaccard = rounded(row, 'jaccard')
        maxcont = rounded(row, 'max_containment')
        intersect_hashes = int(row['intersect_hashes'])
        query_ani = rounded(row, 'query_containment_ani')
        match_ani = rounded(row, 'match_containment_ani')
        average_ani = rounded(row, 'average_containment_ani')
        max_ani = rounded(row, 'max_containment_ani')
        avg_abund = rounded(row, 'average_abund')
        med_abund = rounded(row, 'median_abund')
        std_abund = rounded(row, 'std_abund')

        # Echo the rounded values so a failing run shows what was computed.
        shown = (jaccard, cont, maxcont,
                 query_ani, match_ani, average_ani, max_ani,
                 avg_abund, med_abund, std_abund)
        print(query_id, match_id, *(f"{value}" for value in shown))

        # confirm hand-checked numbers
        assert jaccard == 0.0047
        assert cont == 0.0105
        assert maxcont == 0.0105
        assert intersect_hashes == 44
        assert query_ani == 0.8632
        assert match_ani == 0.8571
        assert average_ani == 0.8602
        assert max_ani == 0.8632
        assert avg_abund == 10.3864
        assert med_abund == 10.5
        assert std_abund == 6.9322


@pytest.mark.parametrize("zip_query", [False, True])
Expand Down
3 changes: 3 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,9 @@ pub struct SearchResult {
pub match_md5: Option<String>,
pub jaccard: Option<f64>,
pub max_containment: Option<f64>,
pub average_abund: f64,
pub median_abund: f64,
pub std_abund: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub query_containment_ani: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down

0 comments on commit 990790f

Please sign in to comment.