diff --git a/pybedlite/overlap_detector.py b/pybedlite/overlap_detector.py index 61d7ee9..0982419 100644 --- a/pybedlite/overlap_detector.py +++ b/pybedlite/overlap_detector.py @@ -41,12 +41,17 @@ import itertools from pathlib import Path from typing import Dict +from typing import Generic +from typing import Hashable from typing import Iterable from typing import Iterator from typing import List from typing import Optional +from typing import Protocol from typing import Set from typing import Type +from typing import TypeVar +from typing import Union import attr @@ -56,6 +61,76 @@ from pybedlite.bed_source import BedSource +class Span(Protocol): + """A genomic span with a start and an end. 0-based open-ended.""" + @property + def start(self) -> int: + """A 0-based start position.""" + + @property + def end(self) -> int: + """A 0-based open-ended position.""" + + +class LocatableTypeWithChrom(Span, Hashable, Protocol): + """A genomic feature where reference sequence is accessed with `chrom`.""" + + @property + def chrom(self) -> str: + """A reference sequence name.""" + + +class LocatableTypeWithContig(Span, Hashable, Protocol): + """A genomic feature where reference sequence is accessed with `contig`.""" + + @property + def contig(self) -> str: + """A reference sequence name.""" + + +class LocatableTypeWithRefName(Span, Hashable, Protocol): + """A genomic feature where reference sequence is accessed with `refname`.""" + + @property + def refname(self) -> str: + """A reference sequence name.""" + + +Locatable = TypeVar( + "Locatable", + bound=Union[LocatableTypeWithChrom, LocatableTypeWithContig, LocatableTypeWithRefName], +) +"""A genomic feature where reference sequence is accessed with the 3 most common words.""" + +LocatableToCompare = TypeVar( + "LocatableToCompare", + bound=Union[LocatableTypeWithChrom, LocatableTypeWithContig, LocatableTypeWithRefName], +) +"""A genomic feature where reference sequence is accessed with the 3 most common words.""" + + +def _reference_name(locatable: Locatable) -> str: + """Return the reference name of a given locatable.""" + if hasattr(locatable, "refname"): + return locatable.refname + elif hasattr(locatable, "chrom"): + return locatable.chrom + elif hasattr(locatable, "contig"): + return locatable.contig + else: + raise ValueError(f"Locatable is missing a reference sequence name property: {locatable}") + + +def _is_negative(locatable: Locatable) -> bool: + """Attempt to determine if this is a negative stranded locatable or not.""" + if hasattr(locatable, "negative"): + return locatable.negative + elif hasattr(locatable, "strand") and isinstance(locatable.strand, BedStrand): + return locatable.strand is BedStrand.Negative + else: + return False + + @attr.s(frozen=True, auto_attribs=True) class Interval: """A region mapping to the genome that is 0-based and open-ended @@ -126,36 +201,51 @@ def from_bedrecord(cls: Type["Interval"], record: BedRecord) -> "Interval": ) -class OverlapDetector(Iterable[Interval]): +class OverlapDetector(Generic[Locatable], Iterable[Locatable]): """Detects and returns overlaps between a set of genomic regions and another genomic region. - Since :class:`~pybedlite.overlap_detector.Interval` objects are used both to populate the - overlap detector and to query it, the coordinate system in use is also 0-based open-ended. + The overlap detector may contain any interval-like Python objects that have the following + properties: + + * `chrom` or `contig` or `refname`: The reference sequence name + * `start`: A 0-based start position + * `end`: A 0-based exclusive end position + + Interval-like Python objects may also contain strandedness information which will be used + for sorting them in :func:`~pybedlite.overlap_detector.OverlapDetector.get_overlaps` using + either of the following properties if they are present: + + * `negative (bool)`: Whether or not the feature is negative stranded or not + * `strand (BedStrand)`: The BED strand of the feature The same interval may be added multiple times, but only a single instance will be returned - when querying for overlaps. Intervals with the same coordinates but different names are - treated as different intervals. + when querying for overlaps. This detector is the most efficient when all intervals are added ahead of time. """ - def __init__(self) -> None: + def __init__(self, intervals: Optional[Iterable[Locatable]] = None) -> None: # A mapping from the contig/chromosome name to the associated interval tree self._refname_to_tree: Dict[str, cr.cgranges] = {} # type: ignore self._refname_to_indexed: Dict[str, bool] = {} - self._refname_to_intervals: Dict[str, List[Interval]] = {} + self._refname_to_intervals: Dict[str, List[Locatable]] = {} + if intervals is not None: + self.add_all(intervals) - def __iter__(self) -> Iterator[Interval]: + def __iter__(self) -> Iterator[Locatable]: """Iterates over the intervals in the overlap detector.""" return itertools.chain(*self._refname_to_intervals.values()) - def add(self, interval: Interval) -> None: + def add(self, interval: Locatable) -> None: """Adds an interval to this detector. Args: interval: the interval to add to this detector """ - refname = interval.refname + if not isinstance(interval, Hashable): + raise ValueError(f"Interval feature is not hashable but should be: {interval}") + + refname = _reference_name(interval) if refname not in self._refname_to_tree: self._refname_to_tree[refname] = cr.cgranges() # type: ignore self._refname_to_indexed[refname] = False @@ -168,13 +258,13 @@ def add(self, interval: Interval) -> None: # Add the interval to the tree tree = self._refname_to_tree[refname] - tree.add(interval.refname, interval.start, interval.end, interval_idx) + tree.add(refname, interval.start, interval.end, interval_idx) # Flag this tree as needing to be indexed after adding a new interval, but defer # indexing self._refname_to_indexed[refname] = False - def add_all(self, intervals: Iterable[Interval]) -> None: + def add_all(self, intervals: Iterable[Locatable]) -> None: """Adds one or more intervals to this detector. Args: @@ -183,7 +273,7 @@ def add_all(self, intervals: Iterable[Interval]) -> None: for interval in intervals: self.add(interval) - def overlaps_any(self, interval: Interval) -> bool: + def overlaps_any(self, interval: LocatableToCompare) -> bool: """Determines whether the given interval overlaps any interval in this detector. Args: @@ -193,20 +283,21 @@ def overlaps_any(self, interval: Interval) -> bool: True if and only if the given interval overlaps with any interval in this detector. """ - tree = self._refname_to_tree.get(interval.refname) + refname = _reference_name(interval) + tree = self._refname_to_tree.get(refname) if tree is None: return False else: - if not self._refname_to_indexed[interval.refname]: + if not self._refname_to_indexed[refname]: tree.index() try: - next(iter(tree.overlap(interval.refname, interval.start, interval.end))) + next(iter(tree.overlap(refname, interval.start, interval.end))) except StopIteration: return False else: return True - def get_overlaps(self, interval: Interval) -> List[Interval]: + def get_overlaps(self, interval: LocatableToCompare) -> List[Locatable]: """Returns any intervals in this detector that overlap the given interval. Args: @@ -216,23 +307,30 @@ def get_overlaps(self, interval: Interval) -> List[Interval]: The list of intervals in this detector that overlap the given interval, or the empty list if no overlaps exist. The intervals will be return in ascending genomic order. """ - tree = self._refname_to_tree.get(interval.refname) + refname = _reference_name(interval) + tree = self._refname_to_tree.get(refname) if tree is None: return [] else: - if not self._refname_to_indexed[interval.refname]: + if not self._refname_to_indexed[refname]: tree.index() - ref_intervals: List[Interval] = self._refname_to_intervals[interval.refname] + ref_intervals: List[Locatable] = self._refname_to_intervals[refname] # NB: only return unique instances of intervals - intervals: Set[Interval] = { + intervals: Set[Locatable] = { ref_intervals[index] - for _, _, index in tree.overlap(interval.refname, interval.start, interval.end) + for _, _, index in tree.overlap(refname, interval.start, interval.end) } return sorted( - intervals, key=lambda intv: (intv.start, intv.end, intv.negative, intv.name) + intervals, + key=lambda intv: ( + intv.start, + intv.end, + _is_negative(intv), + _reference_name(intv), + ), ) - def get_enclosing_intervals(self, interval: Interval) -> List[Interval]: + def get_enclosing_intervals(self, interval: LocatableToCompare) -> List[Locatable]: """Returns the set of intervals in this detector that wholly enclose the query interval. i.e. query.start >= target.start and query.end <= target.end. @@ -245,7 +343,7 @@ def get_enclosing_intervals(self, interval: Interval) -> List[Interval]: results = self.get_overlaps(interval) return [i for i in results if interval.start >= i.start and interval.end <= i.end] - def get_enclosed(self, interval: Interval) -> List[Interval]: + def get_enclosed(self, interval: LocatableToCompare) -> List[Locatable]: """Returns the set of intervals in this detector that are enclosed by the query interval. I.e. target.start >= query.start and target.end <= query.end. @@ -267,9 +365,9 @@ def from_bed(cls, path: Path) -> "OverlapDetector": Returns: An overlap detector for the regions in the BED file. """ - detector = OverlapDetector() + detector: OverlapDetector[BedRecord] = OverlapDetector() for region in BedSource(path): - detector.add(Interval.from_bedrecord(region)) + detector.add(region) return detector diff --git a/pybedlite/tests/test_overlap_detector.py b/pybedlite/tests/test_overlap_detector.py index dae34a9..ff38327 100644 --- a/pybedlite/tests/test_overlap_detector.py +++ b/pybedlite/tests/test_overlap_detector.py @@ -1,6 +1,11 @@ """Tests for :py:mod:`~pybedlite.overlap_detector`""" +from dataclasses import dataclass from typing import List +from typing import TypeAlias +from typing import Union + +import pytest from pybedlite.bed_record import BedRecord from pybedlite.bed_record import BedStrand @@ -9,7 +14,7 @@ def run_test(targets: List[Interval], query: Interval, results: List[Interval]) -> None: - detector = OverlapDetector() + detector: OverlapDetector[Interval] = OverlapDetector() # Use add_all() to covert itself and add() detector.add_all(intervals=targets) # Test overlaps_any() @@ -113,7 +118,7 @@ def test_get_enclosing_intervals() -> None: d = Interval("1", 15, 19) e = Interval("1", 16, 20) - detector = OverlapDetector() + detector: OverlapDetector[Interval] = OverlapDetector() detector.add_all([a, b, c, d, e]) assert detector.get_enclosing_intervals(Interval("1", 10, 100)) == [a] @@ -128,7 +133,7 @@ def test_get_enclosed() -> None: c = Interval("1", 18, 19) d = Interval("1", 50, 99) - detector = OverlapDetector() + detector: OverlapDetector[Interval] = OverlapDetector() detector.add_all([a, b, c, d]) assert detector.get_enclosed(Interval("1", 1, 250)) == [a, b, c, d] @@ -145,7 +150,7 @@ def test_iterable() -> None: d = Interval("1", 15, 19) e = Interval("1", 16, 20) - detector = OverlapDetector() + detector: OverlapDetector[Interval] = OverlapDetector() detector.add_all([a]) assert list(detector) == [a] detector.add_all([a, b, c, d, e]) @@ -188,3 +193,63 @@ def test_construction_from_interval(bed_records: List[BedRecord]) -> None: assert new_record.strand is BedStrand.Positive else: assert new_record.strand is record.strand + + +def test_arbitrary_interval_types() -> None: + """ + Test that an overlap detector can receive different interval-like objects and query them too. + """ + + @dataclass(eq=True, frozen=True) + class ChromFeature: + chrom: str + start: int + end: int + + @dataclass(eq=True, frozen=True) + class ContigFeature: + contig: str + start: int + end: int + + @dataclass(eq=True, frozen=True) + class RefnameFeature: + refname: str + start: int + end: int + + # Create minimal features of all supported structural types + chrom_feature = ChromFeature(chrom="chr1", start=0, end=30) + contig_feature = ContigFeature(contig="chr1", start=10, end=40) + refname_feature = RefnameFeature(refname="chr1", start=20, end=50) + + # Setup an overlap detector to hold all the features we care about + AllKinds: TypeAlias = Union[ChromFeature, ContigFeature, RefnameFeature] + features: List[AllKinds] = [chrom_feature, contig_feature, refname_feature] + detector: OverlapDetector[AllKinds] = OverlapDetector(features) + + # Query the overlap detector with yet another type + assert detector.get_overlaps(Interval("chr1", 0, 1)) == [chrom_feature] + assert detector.get_overlaps(Interval("chr1", 25, 26)) == [ + chrom_feature, + contig_feature, + refname_feature, + ] + assert detector.get_overlaps(Interval("chr1", 45, 46)) == [refname_feature] + + +def test_the_overlap_detector_wont_accept_a_non_hashable_feature() -> None: + """ + Test that an overlap detector will not accept a non-hashable feature. + """ + + @dataclass # A dataclass missing both `eq` and `frozen` does not implement __hash__. + class ChromFeature: + chrom: str + start: int + end: int + + detector = OverlapDetector([ChromFeature(chrom="chr1", start=0, end=30)]) + + with pytest.raises(ValueError): + detector.get_overlaps(Interval("chr1", 0, 1))