From 0e2aaf564c59b8bb7ca4b1749420c18b6cce8ba5 Mon Sep 17 00:00:00 2001 From: clintval Date: Tue, 25 Jun 2024 16:38:25 -0700 Subject: [PATCH] Allow the OverlapDetector to be generic over input feature types --- pybedlite/overlap_detector.py | 170 +++++++++++++++++++---- pybedlite/tests/test_overlap_detector.py | 71 +++++++++- 2 files changed, 210 insertions(+), 31 deletions(-) diff --git a/pybedlite/overlap_detector.py b/pybedlite/overlap_detector.py index 61d7ee9..9697671 100644 --- a/pybedlite/overlap_detector.py +++ b/pybedlite/overlap_detector.py @@ -41,12 +41,17 @@ import itertools from pathlib import Path from typing import Dict +from typing import Generic +from typing import Hashable from typing import Iterable from typing import Iterator from typing import List from typing import Optional +from typing import Protocol from typing import Set from typing import Type +from typing import TypeVar +from typing import Union import attr @@ -56,6 +61,52 @@ from pybedlite.bed_source import BedSource +class _Span(Protocol): + """A genomic span with a start and an end. 0-based open-ended.""" + + @property + def start(self) -> int: + """A 0-based start position.""" + + @property + def end(self) -> int: + """A 0-based open-ended position.""" + + +class _GenomicSpanWithChrom(_Span, Hashable, Protocol): + """A genomic feature where reference sequence is accessed with `chrom`.""" + + @property + def chrom(self) -> str: + """A reference sequence name.""" + + +class _GenomicSpanWithContig(_Span, Hashable, Protocol): + """A genomic feature where reference sequence is accessed with `contig`.""" + + @property + def contig(self) -> str: + """A reference sequence name.""" + + +class _GenomicSpanWithRefName(_Span, Hashable, Protocol): + """A genomic feature where reference sequence is accessed with `refname`.""" + + @property + def refname(self) -> str: + """A reference sequence name.""" + + +GenomicSpan = TypeVar( + "GenomicSpan", + bound=Union[_GenomicSpanWithChrom, _GenomicSpanWithContig, _GenomicSpanWithRefName], +) +""" +A genomic feature where the reference sequence name is accessed with any of the 3 most common +property names ("chrom", "contig", "refname"). +""" + + @attr.s(frozen=True, auto_attribs=True) class Interval: """A region mapping to the genome that is 0-based and open-ended @@ -126,36 +177,93 @@ def from_bedrecord(cls: Type["Interval"], record: BedRecord) -> "Interval": ) -class OverlapDetector(Iterable[Interval]): +_GenericGenomicSpan = TypeVar( + "_GenericGenomicSpan", + bound=Union[_GenomicSpanWithChrom, _GenomicSpanWithContig, _GenomicSpanWithRefName], +) +""" +A generic genomic feature where the reference sequence name is accessed with any of the 3 most +common property names ("chrom", "contig", "refname"). This type variable is used for describing the +generic type contained within the :class:`~pybedlite.overlap_detector.OverlapDetector`. +""" + + +class OverlapDetector(Generic[_GenericGenomicSpan], Iterable[_GenericGenomicSpan]): """Detects and returns overlaps between a set of genomic regions and another genomic region. - Since :class:`~pybedlite.overlap_detector.Interval` objects are used both to populate the - overlap detector and to query it, the coordinate system in use is also 0-based open-ended. + The overlap detector may contain any interval-like Python objects that have the following + properties: + + * `chrom` or `contig` or `refname`: The reference sequence name + * `start`: A 0-based start position + * `end`: A 0-based exclusive end position + + Interval-like Python objects may also contain strandedness information which will be used + for sorting them in :func:`~pybedlite.overlap_detector.OverlapDetector.get_overlaps` using + either of the following properties if they are present: + + * `negative (bool)`: Whether or not the feature is negative stranded or not + * `strand (BedStrand)`: The BED strand of the feature The same interval may be added multiple times, but only a single instance will be returned - when querying for overlaps. Intervals with the same coordinates but different names are - treated as different intervals. + when querying for overlaps. This detector is the most efficient when all intervals are added ahead of time. """ - def __init__(self) -> None: + def __init__(self, intervals: Optional[Iterable[_GenericGenomicSpan]] = None) -> None: # A mapping from the contig/chromosome name to the associated interval tree self._refname_to_tree: Dict[str, cr.cgranges] = {} # type: ignore self._refname_to_indexed: Dict[str, bool] = {} - self._refname_to_intervals: Dict[str, List[Interval]] = {} + self._refname_to_intervals: Dict[str, List[_GenericGenomicSpan]] = {} + if intervals is not None: + self.add_all(intervals) - def __iter__(self) -> Iterator[Interval]: + def __iter__(self) -> Iterator[_GenericGenomicSpan]: """Iterates over the intervals in the overlap detector.""" return itertools.chain(*self._refname_to_intervals.values()) - def add(self, interval: Interval) -> None: + @staticmethod + def _reference_name(interval: GenomicSpan) -> str: + """Return the reference name of a given interval.""" + if isinstance(interval, Interval) or hasattr(interval, "refname"): + return interval.refname + elif isinstance(interval, BedRecord) or hasattr(interval, "chrom"): + return interval.chrom + elif hasattr(interval, "contig"): + return interval.contig + else: + raise ValueError( + f"Genomic span is missing a reference sequence name property: {interval}" + ) + + @staticmethod + def _is_negative(interval: GenomicSpan) -> bool: + """Determine if this is a negative stranded interval or not.""" + return ( + (hasattr(interval, "negative") and interval.negative) + or ( + hasattr(interval, "strand") + and isinstance(interval.strand, BedStrand) + and interval.strand is BedStrand.Negative + ) + or ( + hasattr(interval, "strand") + and isinstance(interval.strand, str) + and interval.strand == "-" + ) + ) + + def add(self, interval: _GenericGenomicSpan) -> None: """Adds an interval to this detector. Args: interval: the interval to add to this detector """ - refname = interval.refname + if not isinstance(interval, Hashable): + raise ValueError(f"Interval feature is not hashable but should be: {interval}") + + refname = self._reference_name(interval) if refname not in self._refname_to_tree: self._refname_to_tree[refname] = cr.cgranges() # type: ignore self._refname_to_indexed[refname] = False @@ -168,13 +276,13 @@ def add(self, interval: Interval) -> None: # Add the interval to the tree tree = self._refname_to_tree[refname] - tree.add(interval.refname, interval.start, interval.end, interval_idx) + tree.add(refname, interval.start, interval.end, interval_idx) # Flag this tree as needing to be indexed after adding a new interval, but defer # indexing self._refname_to_indexed[refname] = False - def add_all(self, intervals: Iterable[Interval]) -> None: + def add_all(self, intervals: Iterable[_GenericGenomicSpan]) -> None: """Adds one or more intervals to this detector. Args: @@ -183,7 +291,7 @@ def add_all(self, intervals: Iterable[Interval]) -> None: for interval in intervals: self.add(interval) - def overlaps_any(self, interval: Interval) -> bool: + def overlaps_any(self, interval: GenomicSpan) -> bool: """Determines whether the given interval overlaps any interval in this detector. Args: @@ -193,20 +301,21 @@ def overlaps_any(self, interval: Interval) -> bool: True if and only if the given interval overlaps with any interval in this detector. """ - tree = self._refname_to_tree.get(interval.refname) + refname = self._reference_name(interval) + tree = self._refname_to_tree.get(refname) if tree is None: return False else: - if not self._refname_to_indexed[interval.refname]: + if not self._refname_to_indexed[refname]: tree.index() try: - next(iter(tree.overlap(interval.refname, interval.start, interval.end))) + next(iter(tree.overlap(refname, interval.start, interval.end))) except StopIteration: return False else: return True - def get_overlaps(self, interval: Interval) -> List[Interval]: + def get_overlaps(self, interval: GenomicSpan) -> List[_GenericGenomicSpan]: """Returns any intervals in this detector that overlap the given interval. Args: @@ -216,23 +325,30 @@ def get_overlaps(self, interval: Interval) -> List[Interval]: The list of intervals in this detector that overlap the given interval, or the empty list if no overlaps exist. The intervals will be return in ascending genomic order. """ - tree = self._refname_to_tree.get(interval.refname) + refname = self._reference_name(interval) + tree = self._refname_to_tree.get(refname) if tree is None: return [] else: - if not self._refname_to_indexed[interval.refname]: + if not self._refname_to_indexed[refname]: tree.index() - ref_intervals: List[Interval] = self._refname_to_intervals[interval.refname] + ref_intervals: List[_GenericGenomicSpan] = self._refname_to_intervals[refname] # NB: only return unique instances of intervals - intervals: Set[Interval] = { + intervals: Set[_GenericGenomicSpan] = { ref_intervals[index] - for _, _, index in tree.overlap(interval.refname, interval.start, interval.end) + for _, _, index in tree.overlap(refname, interval.start, interval.end) } return sorted( - intervals, key=lambda intv: (intv.start, intv.end, intv.negative, intv.name) + intervals, + key=lambda intv: ( + intv.start, + intv.end, + self._is_negative(intv), + self._reference_name(intv), + ), ) - def get_enclosing_intervals(self, interval: Interval) -> List[Interval]: + def get_enclosing_intervals(self, interval: GenomicSpan) -> List[_GenericGenomicSpan]: """Returns the set of intervals in this detector that wholly enclose the query interval. i.e. query.start >= target.start and query.end <= target.end. @@ -245,7 +361,7 @@ def get_enclosing_intervals(self, interval: Interval) -> List[Interval]: results = self.get_overlaps(interval) return [i for i in results if interval.start >= i.start and interval.end <= i.end] - def get_enclosed(self, interval: Interval) -> List[Interval]: + def get_enclosed(self, interval: GenomicSpan) -> List[_GenericGenomicSpan]: """Returns the set of intervals in this detector that are enclosed by the query interval. I.e. target.start >= query.start and target.end <= query.end. @@ -267,9 +383,9 @@ def from_bed(cls, path: Path) -> "OverlapDetector": Returns: An overlap detector for the regions in the BED file. """ - detector = OverlapDetector() + detector: OverlapDetector[BedRecord] = OverlapDetector() for region in BedSource(path): - detector.add(Interval.from_bedrecord(region)) + detector.add(region) return detector diff --git a/pybedlite/tests/test_overlap_detector.py b/pybedlite/tests/test_overlap_detector.py index dae34a9..aa0a143 100644 --- a/pybedlite/tests/test_overlap_detector.py +++ b/pybedlite/tests/test_overlap_detector.py @@ -1,6 +1,11 @@ """Tests for :py:mod:`~pybedlite.overlap_detector`""" +from dataclasses import dataclass from typing import List +from typing import TypeAlias +from typing import Union + +import pytest from pybedlite.bed_record import BedRecord from pybedlite.bed_record import BedStrand @@ -9,7 +14,7 @@ def run_test(targets: List[Interval], query: Interval, results: List[Interval]) -> None: - detector = OverlapDetector() + detector: OverlapDetector[Interval] = OverlapDetector() # Use add_all() to covert itself and add() detector.add_all(intervals=targets) # Test overlaps_any() @@ -113,7 +118,7 @@ def test_get_enclosing_intervals() -> None: d = Interval("1", 15, 19) e = Interval("1", 16, 20) - detector = OverlapDetector() + detector: OverlapDetector[Interval] = OverlapDetector() detector.add_all([a, b, c, d, e]) assert detector.get_enclosing_intervals(Interval("1", 10, 100)) == [a] @@ -128,7 +133,7 @@ def test_get_enclosed() -> None: c = Interval("1", 18, 19) d = Interval("1", 50, 99) - detector = OverlapDetector() + detector: OverlapDetector[Interval] = OverlapDetector() detector.add_all([a, b, c, d]) assert detector.get_enclosed(Interval("1", 1, 250)) == [a, b, c, d] @@ -145,7 +150,7 @@ def test_iterable() -> None: d = Interval("1", 15, 19) e = Interval("1", 16, 20) - detector = OverlapDetector() + detector: OverlapDetector[Interval] = OverlapDetector() detector.add_all([a]) assert list(detector) == [a] detector.add_all([a, b, c, d, e]) @@ -188,3 +193,61 @@ def test_construction_from_interval(bed_records: List[BedRecord]) -> None: assert new_record.strand is BedStrand.Positive else: assert new_record.strand is record.strand + + +def test_arbitrary_interval_types() -> None: + """ + Test that an overlap detector can receive different interval-like objects and query them too. + """ + + @dataclass(eq=True, frozen=True) + class ChromFeature: + chrom: str + start: int + end: int + + @dataclass(eq=True, frozen=True) + class ContigFeature: + contig: str + start: int + end: int + + @dataclass(eq=True, frozen=True) + class RefnameFeature: + refname: str + start: int + end: int + + # Create minimal features of all supported structural types + chrom_feature = ChromFeature(chrom="chr1", start=0, end=30) + contig_feature = ContigFeature(contig="chr1", start=10, end=40) + refname_feature = RefnameFeature(refname="chr1", start=20, end=50) + + # Setup an overlap detector to hold all the features we care about + AllKinds: TypeAlias = Union[ChromFeature, ContigFeature, RefnameFeature] + features: List[AllKinds] = [chrom_feature, contig_feature, refname_feature] + detector: OverlapDetector[AllKinds] = OverlapDetector(features) + + # Query the overlap detector with yet another type + assert detector.get_overlaps(Interval("chr1", 0, 1)) == [chrom_feature] + assert detector.get_overlaps(Interval("chr1", 25, 26)) == [ + chrom_feature, + contig_feature, + refname_feature, + ] + assert detector.get_overlaps(Interval("chr1", 45, 46)) == [refname_feature] + + +def test_the_overlap_detector_wont_accept_a_non_hashable_feature() -> None: + """ + Test that an overlap detector will not accept a non-hashable feature. + """ + + @dataclass # A dataclass missing both `eq` and `frozen` does not implement __hash__. + class ChromFeature: + chrom: str + start: int + end: int + + with pytest.raises(ValueError): + OverlapDetector([ChromFeature(chrom="chr1", start=0, end=30)])