Skip to content

Commit

Permalink
refactor calculate_gather_stats to take ds query & return isect
Browse files Browse the repository at this point in the history
  • Loading branch information
ctb committed Oct 14, 2024
1 parent ceaea39 commit 0eeca48
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 28 deletions.
31 changes: 17 additions & 14 deletions src/core/src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use getset::{CopyGetters, Getters, Setters};
use log::trace;
use serde::{Deserialize, Serialize};
use stats::{median, stddev};
use std::cmp::max;
use typed_builder::TypedBuilder;

use crate::ani_utils::{ani_ci_from_containment, ani_from_containment};
Expand All @@ -29,6 +28,7 @@ use crate::signature::SigsTrait;
use crate::sketch::minhash::KmerMinHash;
use crate::storage::SigStore;
use crate::Result;
use crate::Error::CannotUpsampleScaled;

#[derive(TypedBuilder, CopyGetters, Getters, Setters, Serialize, Deserialize, Debug, PartialEq)]
pub struct GatherResult {
Expand Down Expand Up @@ -209,7 +209,7 @@ where
#[allow(clippy::too_many_arguments)]
pub fn calculate_gather_stats(
orig_query: &KmerMinHash,
query: KmerMinHash,
query_foo: KmerMinHash,
match_sig: SigStore,
match_size: usize,
gather_result_rank: usize,
Expand All @@ -218,29 +218,31 @@ pub fn calculate_gather_stats(
calc_abund_stats: bool,
calc_ani_ci: bool,
confidence: Option<f64>,
) -> Result<GatherResult> {
) -> Result<(GatherResult, (Vec<u64>, u64))> {
// get match_mh
let match_mh = match_sig.minhash().expect("cannot retrieve sketch");

let max_scaled = max(match_mh.scaled(), query.scaled());
let query = query
.downsample_scaled(max_scaled)
.expect("cannot downsample query");
// it's ok to downsample match, but query is often big and repeated,
// so we do not allow downsampling here.
if match_mh.scaled() > query_foo.scaled() {
return Err(CannotUpsampleScaled);
}

let match_mh = match_mh
.clone()
.downsample_scaled(max_scaled)
.downsample_scaled(query_foo.scaled())
.expect("cannot downsample match");

// calculate intersection
let isect = match_mh
.intersection(&query)
.intersection(&query_foo)
.expect("could not do intersection");
let isect_size = isect.0.len();
trace!("isect_size: {}", isect_size);
trace!("query.size: {}", query.size());
trace!("query.size: {}", query_foo.size());

//bp remaining in subtracted query
let remaining_bp = (query.size() - isect_size) * query.scaled() as usize;
let remaining_bp = (query_foo.size() - isect_size) * query_foo.scaled() as usize;

// stats for this match vs original query
let (intersect_orig, _) = match_mh.intersection_size(orig_query).unwrap();
Expand Down Expand Up @@ -300,7 +302,7 @@ pub fn calculate_gather_stats(
// If abundance, calculate abund-related metrics (vs current query)
if calc_abund_stats {
// take abunds from subtracted query
let (abunds, unique_weighted_found) = match match_mh.inflated_abundances(&query) {
let (abunds, unique_weighted_found) = match match_mh.inflated_abundances(&query_foo) {
Ok((abunds, unique_weighted_found)) => (abunds, unique_weighted_found),
Err(e) => {
return Err(e);
Expand Down Expand Up @@ -347,7 +349,7 @@ pub fn calculate_gather_stats(
.sum_weighted_found(sum_total_weighted_found)
.total_weighted_hashes(total_weighted_hashes)
.build();
Ok(result)
Ok((result, isect))
}

#[cfg(test)]
Expand Down Expand Up @@ -403,7 +405,7 @@ mod test_calculate_gather_stats {
let gather_result_rank = 0;
let calc_abund_stats = true;
let calc_ani_ci = false;
let result = calculate_gather_stats(
let (result, _isect) = calculate_gather_stats(
&orig_query,
query,
match_sig.into(),
Expand All @@ -416,6 +418,7 @@ mod test_calculate_gather_stats {
None,
)
.unwrap();

// first, print all results
assert_eq!(result.filename(), "match-filename");
assert_eq!(result.name(), "match-name");
Expand Down
31 changes: 17 additions & 14 deletions src/core/src/index/revindex/disk_revindex.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::max;
use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher};
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
Expand Down Expand Up @@ -393,30 +394,26 @@ impl RevIndexOps for RevIndex {
}

let match_sig = self.collection.sig_for_dataset(dataset_id)?;

// get downsampled minhashes for comparison.
let match_mh = match_sig.minhash().unwrap().clone();
let scaled = query.scaled();

// make downsampled minhashes
let max_scaled = max(match_mh.scaled(), query.scaled());

let match_mh = match_mh
.downsample_scaled(scaled)
.downsample_scaled(max_scaled)
.expect("cannot downsample match");

query = query
.downsample_scaled(max_scaled)
.expect("cannot downsample query");
let query_mh = KmerMinHash::from(query.clone());

// just calculate essentials here
let gather_result_rank = matches.len();

let query_mh = KmerMinHash::from(query.clone());

// grab the specific intersection:
let isect = match_mh
.intersection(&query_mh)
.expect("failed to intersect");
let mut isect_mh = match_mh.clone();
isect_mh.clear();
isect_mh.add_many(&isect.0)?;

// Calculate stats
let gather_result = calculate_gather_stats(
let (gather_result, isect) = calculate_gather_stats(
&orig_query,
query_mh,
match_sig,
Expand All @@ -429,6 +426,12 @@ impl RevIndexOps for RevIndex {
ani_confidence_interval_fraction,
)
.expect("could not calculate gather stats");

// use intersection from calc_gather_stats to make a KmerMinHash.
let mut isect_mh = match_mh.clone();
isect_mh.clear();
isect_mh.add_many(&isect.0)?;

// keep track of the sum weighted found
sum_weighted_found = gather_result.sum_weighted_found();
matches.push(gather_result);
Expand Down

0 comments on commit 0eeca48

Please sign in to comment.