Skip to content

Commit

Permalink
Allow the OverlapDetector to be generic over input feature types
Browse files Browse the repository at this point in the history
  • Loading branch information
clintval committed Jun 25, 2024
1 parent 3a4b23a commit 5c82e01
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 31 deletions.
152 changes: 125 additions & 27 deletions pybedlite/overlap_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,17 @@
import itertools
from pathlib import Path
from typing import Dict
from typing import Generic
from typing import Hashable
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Protocol
from typing import Set
from typing import Type
from typing import TypeVar
from typing import Union

import attr

Expand All @@ -56,6 +61,76 @@
from pybedlite.bed_source import BedSource


class Span(Protocol):
    """A structural type for anything that occupies a stretch of a reference sequence.

    Coordinates are 0-based and open-ended.
    """

    @property
    def start(self) -> int:
        """The 0-based position at which the span begins."""
        ...

    @property
    def end(self) -> int:
        """The 0-based open-ended position at which the span stops."""
        ...


class LocatableTypeWithChrom(Span, Hashable, Protocol):
    """A hashable genomic span that names its reference sequence through `chrom`."""

    @property
    def chrom(self) -> str:
        """The name of the reference sequence."""
        ...


class LocatableTypeWithContig(Span, Hashable, Protocol):
    """A hashable genomic span that names its reference sequence through `contig`."""

    @property
    def contig(self) -> str:
        """The name of the reference sequence."""
        ...


class LocatableTypeWithRefName(Span, Hashable, Protocol):
    """A hashable genomic span that names its reference sequence through `refname`."""

    @property
    def refname(self) -> str:
        """The name of the reference sequence."""
        ...


# Type variable for the features *stored* in an overlap detector: it parameterizes
# `OverlapDetector` and is the type accepted by `add()` and yielded by iteration. It is bound to
# the union of the three structural feature types, which differ only in the name of the attribute
# holding the reference sequence name (`chrom`, `contig`, or `refname`).
Locatable = TypeVar(
    "Locatable",
    bound=Union[LocatableTypeWithChrom, LocatableTypeWithContig, LocatableTypeWithRefName],
)
"""A genomic feature where reference sequence is accessed with the 3 most common words."""

# A second type variable with the same bound, used for the *query* argument of the detector's
# lookup methods. Keeping it distinct from `Locatable` lets the query feature be of a different
# type than the features held by the detector.
LocatableToCompare = TypeVar(
    "LocatableToCompare",
    bound=Union[LocatableTypeWithChrom, LocatableTypeWithContig, LocatableTypeWithRefName],
)
"""A genomic feature where reference sequence is accessed with the 3 most common words."""


def _reference_name(locatable: Locatable) -> str:
"""Return the reference name of a given locatable."""
if hasattr(locatable, "refname"):
return locatable.refname
elif hasattr(locatable, "chrom"):
return locatable.chrom
elif hasattr(locatable, "contig"):
return locatable.contig
else:
raise ValueError(f"Locatable is missing a reference sequence name property: {locatable}")


def _is_negative(locatable: Locatable) -> bool:
"""Attempt to determine if this is a negative stranded locatable or not."""
if hasattr(locatable, "negative"):
return locatable.negative
elif hasattr(locatable, "strand") and isinstance(locatable.strand, BedStrand):
return locatable.strand is BedStrand.Negative
else:
return False


@attr.s(frozen=True, auto_attribs=True)
class Interval:
"""A region mapping to the genome that is 0-based and open-ended
Expand Down Expand Up @@ -126,36 +201,51 @@ def from_bedrecord(cls: Type["Interval"], record: BedRecord) -> "Interval":
)


class OverlapDetector(Iterable[Interval]):
class OverlapDetector(Generic[Locatable], Iterable[Locatable]):
"""Detects and returns overlaps between a set of genomic regions and another genomic region.
Since :class:`~pybedlite.overlap_detector.Interval` objects are used both to populate the
overlap detector and to query it, the coordinate system in use is also 0-based open-ended.
The overlap detector may contain any interval-like Python objects that have the following
properties:
* `chrom` or `contig` or `refname`: The reference sequence name
* `start`: A 0-based start position
* `end`: A 0-based exclusive end position
Interval-like Python objects may also contain strandedness information which will be used
for sorting them in :func:`~pybedlite.overlap_detector.OverlapDetector.get_overlaps` using
either of the following properties if they are present:
* `negative (bool)`: Whether or not the feature is negative stranded or not
* `strand (BedStrand)`: The BED strand of the feature
The same interval may be added multiple times, but only a single instance will be returned
when querying for overlaps. Intervals with the same coordinates but different names are
treated as different intervals.
when querying for overlaps.
This detector is the most efficient when all intervals are added ahead of time.
"""

def __init__(self) -> None:
def __init__(self, intervals: Optional[Iterable[Locatable]] = None) -> None:
# A mapping from the contig/chromosome name to the associated interval tree
self._refname_to_tree: Dict[str, cr.cgranges] = {} # type: ignore
self._refname_to_indexed: Dict[str, bool] = {}
self._refname_to_intervals: Dict[str, List[Interval]] = {}
self._refname_to_intervals: Dict[str, List[Locatable]] = {}
if intervals is not None:
self.add_all(intervals)

def __iter__(self) -> Iterator[Interval]:
def __iter__(self) -> Iterator[Locatable]:
"""Iterates over the intervals in the overlap detector."""
return itertools.chain(*self._refname_to_intervals.values())

def add(self, interval: Interval) -> None:
def add(self, interval: Locatable) -> None:
"""Adds an interval to this detector.
Args:
interval: the interval to add to this detector
"""
refname = interval.refname
if not isinstance(interval, Hashable):
raise ValueError(f"Interval feature is not hashable but should be: {interval}")

refname = _reference_name(interval)
if refname not in self._refname_to_tree:
self._refname_to_tree[refname] = cr.cgranges() # type: ignore
self._refname_to_indexed[refname] = False
Expand All @@ -168,13 +258,13 @@ def add(self, interval: Interval) -> None:

# Add the interval to the tree
tree = self._refname_to_tree[refname]
tree.add(interval.refname, interval.start, interval.end, interval_idx)
tree.add(refname, interval.start, interval.end, interval_idx)

# Flag this tree as needing to be indexed after adding a new interval, but defer
# indexing
self._refname_to_indexed[refname] = False

def add_all(self, intervals: Iterable[Interval]) -> None:
def add_all(self, intervals: Iterable[Locatable]) -> None:
"""Adds one or more intervals to this detector.
Args:
Expand All @@ -183,7 +273,7 @@ def add_all(self, intervals: Iterable[Interval]) -> None:
for interval in intervals:
self.add(interval)

def overlaps_any(self, interval: Interval) -> bool:
def overlaps_any(self, interval: LocatableToCompare) -> bool:
"""Determines whether the given interval overlaps any interval in this detector.
Args:
Expand All @@ -193,20 +283,21 @@ def overlaps_any(self, interval: Interval) -> bool:
True if and only if the given interval overlaps with any interval in this
detector.
"""
tree = self._refname_to_tree.get(interval.refname)
refname = _reference_name(interval)
tree = self._refname_to_tree.get(refname)
if tree is None:
return False
else:
if not self._refname_to_indexed[interval.refname]:
if not self._refname_to_indexed[refname]:
tree.index()
try:
next(iter(tree.overlap(interval.refname, interval.start, interval.end)))
next(iter(tree.overlap(refname, interval.start, interval.end)))
except StopIteration:
return False
else:
return True

def get_overlaps(self, interval: Interval) -> List[Interval]:
def get_overlaps(self, interval: LocatableToCompare) -> List[Locatable]:
"""Returns any intervals in this detector that overlap the given interval.
Args:
Expand All @@ -216,23 +307,30 @@ def get_overlaps(self, interval: Interval) -> List[Interval]:
The list of intervals in this detector that overlap the given interval, or the empty
list if no overlaps exist. The intervals will be returned in ascending genomic order.
"""
tree = self._refname_to_tree.get(interval.refname)
refname = _reference_name(interval)
tree = self._refname_to_tree.get(refname)
if tree is None:
return []
else:
if not self._refname_to_indexed[interval.refname]:
if not self._refname_to_indexed[refname]:
tree.index()
ref_intervals: List[Interval] = self._refname_to_intervals[interval.refname]
ref_intervals: List[Locatable] = self._refname_to_intervals[refname]
# NB: only return unique instances of intervals
intervals: Set[Interval] = {
intervals: Set[Locatable] = {
ref_intervals[index]
for _, _, index in tree.overlap(interval.refname, interval.start, interval.end)
for _, _, index in tree.overlap(refname, interval.start, interval.end)
}
return sorted(
intervals, key=lambda intv: (intv.start, intv.end, intv.negative, intv.name)
intervals,
key=lambda intv: (
intv.start,
intv.end,
_is_negative(intv),
_reference_name(intv),
),
)

def get_enclosing_intervals(self, interval: Interval) -> List[Interval]:
def get_enclosing_intervals(self, interval: LocatableToCompare) -> List[Locatable]:
"""Returns the set of intervals in this detector that wholly enclose the query interval.
i.e. query.start >= target.start and query.end <= target.end.
Expand All @@ -245,7 +343,7 @@ def get_enclosing_intervals(self, interval: Interval) -> List[Interval]:
results = self.get_overlaps(interval)
return [i for i in results if interval.start >= i.start and interval.end <= i.end]

def get_enclosed(self, interval: Interval) -> List[Interval]:
def get_enclosed(self, interval: LocatableToCompare) -> List[Locatable]:
"""Returns the set of intervals in this detector that are enclosed by the query
interval. I.e. target.start >= query.start and target.end <= query.end.
Expand All @@ -267,9 +365,9 @@ def from_bed(cls, path: Path) -> "OverlapDetector":
Returns:
An overlap detector for the regions in the BED file.
"""
detector = OverlapDetector()
detector: OverlapDetector[BedRecord] = OverlapDetector()

for region in BedSource(path):
detector.add(Interval.from_bedrecord(region))
detector.add(region)

return detector
73 changes: 69 additions & 4 deletions pybedlite/tests/test_overlap_detector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""Tests for :py:mod:`~pybedlite.overlap_detector`"""

from dataclasses import dataclass
from typing import List
from typing import TypeAlias
from typing import Union

import pytest

from pybedlite.bed_record import BedRecord
from pybedlite.bed_record import BedStrand
Expand All @@ -9,7 +14,7 @@


def run_test(targets: List[Interval], query: Interval, results: List[Interval]) -> None:
detector = OverlapDetector()
detector: OverlapDetector[Interval] = OverlapDetector()
# Use add_all() to cover itself and add()
detector.add_all(intervals=targets)
# Test overlaps_any()
Expand Down Expand Up @@ -113,7 +118,7 @@ def test_get_enclosing_intervals() -> None:
d = Interval("1", 15, 19)
e = Interval("1", 16, 20)

detector = OverlapDetector()
detector: OverlapDetector[Interval] = OverlapDetector()
detector.add_all([a, b, c, d, e])

assert detector.get_enclosing_intervals(Interval("1", 10, 100)) == [a]
Expand All @@ -128,7 +133,7 @@ def test_get_enclosed() -> None:
c = Interval("1", 18, 19)
d = Interval("1", 50, 99)

detector = OverlapDetector()
detector: OverlapDetector[Interval] = OverlapDetector()
detector.add_all([a, b, c, d])

assert detector.get_enclosed(Interval("1", 1, 250)) == [a, b, c, d]
Expand All @@ -145,7 +150,7 @@ def test_iterable() -> None:
d = Interval("1", 15, 19)
e = Interval("1", 16, 20)

detector = OverlapDetector()
detector: OverlapDetector[Interval] = OverlapDetector()
detector.add_all([a])
assert list(detector) == [a]
detector.add_all([a, b, c, d, e])
Expand Down Expand Up @@ -188,3 +193,63 @@ def test_construction_from_interval(bed_records: List[BedRecord]) -> None:
assert new_record.strand is BedStrand.Positive
else:
assert new_record.strand is record.strand


def test_arbitrary_interval_types() -> None:
    """
    Test that a detector can hold a mix of interval-like types and be queried with yet another.
    """

    @dataclass(eq=True, frozen=True)
    class ChromFeature:
        chrom: str
        start: int
        end: int

    @dataclass(eq=True, frozen=True)
    class ContigFeature:
        contig: str
        start: int
        end: int

    @dataclass(eq=True, frozen=True)
    class RefnameFeature:
        refname: str
        start: int
        end: int

    # One minimal, staggered feature per supported way of naming the reference sequence
    by_chrom = ChromFeature(chrom="chr1", start=0, end=30)
    by_contig = ContigFeature(contig="chr1", start=10, end=40)
    by_refname = RefnameFeature(refname="chr1", start=20, end=50)

    # Load every kind of feature into a single detector through its constructor
    AllKinds: TypeAlias = Union[ChromFeature, ContigFeature, RefnameFeature]
    stored: List[AllKinds] = [by_chrom, by_contig, by_refname]
    detector: OverlapDetector[AllKinds] = OverlapDetector(stored)

    # The queries are `Interval` objects, a type that is not among the stored ones
    only_first = Interval("chr1", 0, 1)
    all_three = Interval("chr1", 25, 26)
    only_last = Interval("chr1", 45, 46)

    assert detector.get_overlaps(only_first) == [by_chrom]
    assert detector.get_overlaps(all_three) == [by_chrom, by_contig, by_refname]
    assert detector.get_overlaps(only_last) == [by_refname]


def test_the_overlap_detector_wont_accept_a_non_hashable_feature() -> None:
    """
    Test that an overlap detector will not accept a non-hashable feature.
    """

    # A plain `@dataclass` defaults to `eq=True, frozen=False`, which sets `__hash__` to None and
    # so makes instances of the class unhashable.
    @dataclass
    class ChromFeature:
        chrom: str
        start: int
        end: int

    feature = ChromFeature(chrom="chr1", start=0, end=30)

    # The hashability check is made when a feature is added, and the constructor adds the features
    # it is given. The detector must therefore be built *inside* the `raises` block: building it
    # beforehand would let the ValueError escape and fail the test before any query is made.
    with pytest.raises(ValueError, match="not hashable"):
        OverlapDetector([feature])

0 comments on commit 5c82e01

Please sign in to comment.