Skip to content

Commit

Permalink
MRG: improve downsampling behavior on KmerMinHash; fix `RevIndex::g…
Browse files Browse the repository at this point in the history
…ather` bug around `scaled`. (#3342)

This PR does five things:

First, it swaps the implementation of `KmerMinHash::downsample_max_hash`
with `KmerMinHash::downsample_scaled`, and the same for
`KmerMinHashBTree`. Previously a call to `downsample_scaled` calculated
the right `max_hash` from `scaled`, then called `downsample_max_hash`,
which then converted `max_hash` back to `scaled`. This reverses the
logic so that (slightly) less work is done and, more importantly, the
code is a bit more straightforward.

Second, it changes the `downsample_*` functions so that they do not
downsample when no downsampling is needed. As part of this the method
signatures are changed to take an object, rather than a reference. This
lets the functions return an unmodified `KmerMinHash` when no
downsampling is needed.

Third, it turns out the `downsample_*` functions didn't check to make
sure that the new `scaled` value was larger than the old one, i.e. they
didn't prevent upsampling. That check was added and a new error,
`CannotUpsampleScaled`, was added to sourmash core.

Fourth, this uncovered a bug in `RevIndex::gather` where the query was
downsampled to the match, even when the match was lower scaled. This PR
rejiggers the code so that downsampling is done appropriately in the
`gather` and `calculate_gather_stats`. Since `RevIndex::gather` isn't
used in the the sourmash CLI, the bug only presented in the test suite
and in the branchwater plugin; see
sourmash-bio/sourmash_plugin_branchwater#468
and
sourmash-bio/sourmash_plugin_branchwater#467,
where a fastmultigather test had to be fixed because of the incorrect
scaled values output by `RevIndex::gather`.

Fifth, it includes #3348
from @luizirber, which adds a `Signature::try_into()` to `KmerMinHash`
to support the elimination of some clones.

Because of the method signature change for the `downsample_*` functions,
the sourmash-core version needs to be bumped to a new major version,
0.16.0.

It's been a fun journey! 😅 

Fixes #3343

Some notes on further changes and performance implications:

As a consequence of the `RevIndex::gather` changes, redundant
downsampling has to be done in `RevIndex::gather` and
`calculate_gather_stats`, unless we want to change the method signature
of `calculate_gather_stats`. I decided the PR was big enough that I
didn't want to do that in addition. It should not affect most use cases
where `scaled` is the same, and we will see if it results in any
slowdowns over in the branchwater plugin. See
#3196 for an issue on all
of this.

We could also just insist that the query scaled is the one to pay
attention to, per #2951.
This would simplify the code in Python-land as well.

Overall, the performance implications of this PR are not clear.
Previously downsampling was being done even when it wasn't needed, so
this may speed things up quite a lot for our typical use case! On the
other hand, redundant downsampling will happen in cases where there are
scaled mismatches. We just need to benchmark it, I think.

Some preliminary benchmarking reported in
sourmash-bio/sourmash_plugin_branchwater#430 (comment)
suggests that fastgather is now much more memory effficient 🎉 so that's
good!

TODO:
- [x] resolve the scaled mismatch stuff. do we return an `Err` or what
if the downsampling can't be performed?
- [x] update PR description
- [x] add more tests for downsampling, and maybe for gather
- [x] play with this code over in the branchwater plugin too!
sourmash-bio/sourmash_plugin_branchwater#467

---------

Co-authored-by: Luiz Irber <[email protected]>
  • Loading branch information
ctb and luizirber authored Oct 15, 2024
1 parent 1a6312b commit 7d11173
Show file tree
Hide file tree
Showing 12 changed files with 202 additions and 72 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions include/sourmash.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ enum SourmashErrorCode {
SOURMASH_ERROR_CODE_NON_EMPTY_MIN_HASH = 106,
SOURMASH_ERROR_CODE_MISMATCH_NUM = 107,
SOURMASH_ERROR_CODE_NEEDS_ABUNDANCE_TRACKING = 108,
SOURMASH_ERROR_CODE_CANNOT_UPSAMPLE_SCALED = 109,
SOURMASH_ERROR_CODE_NO_MIN_HASH_FOUND = 110,
SOURMASH_ERROR_CODE_EMPTY_SIGNATURE = 111,
SOURMASH_ERROR_CODE_MULTIPLE_SKETCHES_FOUND = 112,
SOURMASH_ERROR_CODE_INVALID_DNA = 1101,
SOURMASH_ERROR_CODE_INVALID_PROT = 1102,
SOURMASH_ERROR_CODE_INVALID_CODON_LENGTH = 1103,
Expand Down
2 changes: 1 addition & 1 deletion src/core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sourmash"
version = "0.15.2"
version = "0.16.0"
authors = ["Luiz Irber <[email protected]>", "N. Tessa Pierce-Ward <[email protected]>"]
description = "tools for comparing biological sequences with k-mer sketches"
repository = "https://github.com/sourmash-bio/sourmash"
Expand Down
1 change: 1 addition & 0 deletions src/core/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ mod test {
use crate::prelude::Select;
use crate::selection::Selection;
use crate::signature::Signature;
#[cfg(all(feature = "branchwater", not(target_arch = "wasm32")))]
use crate::Result;

#[test]
Expand Down
20 changes: 20 additions & 0 deletions src/core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ pub enum SourmashError {
#[error("internal error: {message:?}")]
Internal { message: String },

#[error("new scaled smaller than previous; cannot upsample")]
CannotUpsampleScaled,

#[error("must have same num: {n1} != {n2}")]
MismatchNum { n1: u32, n2: u32 },

Expand All @@ -28,6 +31,15 @@ pub enum SourmashError {
#[error("sketch needs abundance for this operation")]
NeedsAbundanceTracking,

#[error("Expected a MinHash sketch in this signature")]
NoMinHashFound,

#[error("Empty signature")]
EmptySignature,

#[error("Multiple sketches found, expected one")]
MultipleSketchesFound,

#[error("Invalid hash function: {function:?}")]
InvalidHashFunction { function: String },

Expand Down Expand Up @@ -104,6 +116,10 @@ pub enum SourmashErrorCode {
NonEmptyMinHash = 1_06,
MismatchNum = 1_07,
NeedsAbundanceTracking = 1_08,
CannotUpsampleScaled = 1_09,
NoMinHashFound = 1_10,
EmptySignature = 1_11,
MultipleSketchesFound = 1_12,
// Input sequence errors
InvalidDNA = 11_01,
InvalidProt = 11_02,
Expand Down Expand Up @@ -132,6 +148,7 @@ impl SourmashErrorCode {
match error {
SourmashError::Internal { .. } => SourmashErrorCode::Internal,
SourmashError::Panic { .. } => SourmashErrorCode::Panic,
SourmashError::CannotUpsampleScaled { .. } => SourmashErrorCode::CannotUpsampleScaled,
SourmashError::MismatchNum { .. } => SourmashErrorCode::MismatchNum,
SourmashError::NeedsAbundanceTracking { .. } => {
SourmashErrorCode::NeedsAbundanceTracking
Expand All @@ -142,6 +159,9 @@ impl SourmashErrorCode {
SourmashError::MismatchSeed => SourmashErrorCode::MismatchSeed,
SourmashError::MismatchSignatureType => SourmashErrorCode::MismatchSignatureType,
SourmashError::NonEmptyMinHash { .. } => SourmashErrorCode::NonEmptyMinHash,
SourmashError::NoMinHashFound => SourmashErrorCode::NoMinHashFound,
SourmashError::EmptySignature => SourmashErrorCode::EmptySignature,
SourmashError::MultipleSketchesFound => SourmashErrorCode::MultipleSketchesFound,
SourmashError::InvalidDNA { .. } => SourmashErrorCode::InvalidDNA,
SourmashError::InvalidProt { .. } => SourmashErrorCode::InvalidProt,
SourmashError::InvalidCodonLength { .. } => SourmashErrorCode::InvalidCodonLength,
Expand Down
17 changes: 14 additions & 3 deletions src/core/src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use getset::{CopyGetters, Getters, Setters};
use log::trace;
use serde::{Deserialize, Serialize};
use stats::{median, stddev};
use std::cmp::max;
use typed_builder::TypedBuilder;

use crate::ani_utils::{ani_ci_from_containment, ani_from_containment};
Expand Down Expand Up @@ -205,7 +206,6 @@ where
}
}

// note all mh should be selected/downsampled prior to being passed in here.
#[allow(clippy::too_many_arguments)]
pub fn calculate_gather_stats(
orig_query: &KmerMinHash,
Expand All @@ -220,10 +220,21 @@ pub fn calculate_gather_stats(
confidence: Option<f64>,
) -> Result<GatherResult> {
// get match_mh
let match_mh = match_sig.minhash().unwrap();
let match_mh = match_sig.minhash().expect("cannot retrieve sketch");

let max_scaled = max(match_mh.scaled(), query.scaled());
let query = query
.downsample_scaled(max_scaled)
.expect("cannot downsample query");
let match_mh = match_mh
.clone()
.downsample_scaled(max_scaled)
.expect("cannot downsample match");

// calculate intersection
let isect = match_mh.intersection(&query)?;
let isect = match_mh
.intersection(&query)
.expect("could not do intersection");
let isect_size = isect.0.len();
trace!("isect_size: {}", isect_size);
trace!("query.size: {}", query.size());
Expand Down
20 changes: 12 additions & 8 deletions src/core/src/index/revindex/disk_revindex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,6 @@ impl RevIndexOps for RevIndex {
let mut query = KmerMinHashBTree::from(orig_query.clone());
let mut sum_weighted_found = 0;
let _selection = selection.unwrap_or_else(|| self.collection.selection());
let mut orig_query_ds = orig_query.clone();
let total_weighted_hashes = orig_query.sum_abunds();

// or set this with user --track-abundance?
Expand All @@ -393,29 +392,33 @@ impl RevIndexOps for RevIndex {
break;
}

// this should downsample mh for us
let match_sig = self.collection.sig_for_dataset(dataset_id)?;

// get downsampled minhashes for comparison.
let match_mh = match_sig.minhash().unwrap().clone();
query = query.downsample_scaled(match_mh.scaled())?;
orig_query_ds = orig_query_ds.downsample_scaled(match_mh.scaled())?;
let scaled = query.scaled();

let match_mh = match_mh
.downsample_scaled(scaled)
.expect("cannot downsample match");

// just calculate essentials here
let gather_result_rank = matches.len();

let query_mh = KmerMinHash::from(query.clone());

// grab the specific intersection:
let isect = match_mh.intersection(&query_mh)?;
let isect = match_mh
.intersection(&query_mh)
.expect("failed to intersect");
let mut isect_mh = match_mh.clone();
isect_mh.clear();
isect_mh.add_many(&isect.0)?;

// Calculate stats
let gather_result = calculate_gather_stats(
&orig_query_ds,
KmerMinHash::from(query.clone()),
&orig_query,
query_mh,
match_sig,
match_size,
gather_result_rank,
Expand All @@ -424,7 +427,8 @@ impl RevIndexOps for RevIndex {
calc_abund_stats,
calc_ani_ci,
ani_confidence_interval_fraction,
)?;
)
.expect("could not calculate gather stats");
// keep track of the sum weighted found
sum_weighted_found = gather_result.sum_weighted_found();
matches.push(gather_result);
Expand Down
18 changes: 10 additions & 8 deletions src/core/src/index/revindex/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,14 +903,16 @@ mod test {

let (counter, query_colors, hash_to_color) = index.prepare_gather_counters(&query);

let matches_external = index.gather(
counter,
query_colors,
hash_to_color,
0,
&query,
Some(selection.clone()),
)?;
let matches_external = index
.gather(
counter,
query_colors,
hash_to_color,
0,
&query,
Some(selection.clone()),
)
.expect("failed to gather!");

{
let mut index = index;
Expand Down
24 changes: 23 additions & 1 deletion src/core/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ impl Select for Signature {
// TODO: also account for LargeMinHash
if let Sketch::MinHash(mh) = sketch {
if (mh.scaled() as u32) < sel_scaled {
*sketch = Sketch::MinHash(mh.downsample_scaled(sel_scaled as u64)?);
*sketch = Sketch::MinHash(mh.clone().downsample_scaled(sel_scaled as u64)?);
}
}
}
Expand Down Expand Up @@ -887,6 +887,28 @@ impl PartialEq for Signature {
}
}

impl TryInto<KmerMinHash> for Signature {
type Error = Error;

fn try_into(self) -> Result<KmerMinHash, Error> {
match self.signatures.len() {
1 => self
.signatures
.into_iter()
.find_map(|sk| {
if let Sketch::MinHash(mh) = sk {
Some(mh)
} else {
None
}
})
.ok_or_else(|| Error::NoMinHashFound),
0 => Err(Error::EmptySignature),
_ => Err(Error::MultipleSketchesFound),
}
}
}

#[cfg(test)]
mod test {
use std::fs::File;
Expand Down
Loading

0 comments on commit 7d11173

Please sign in to comment.