Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the OverlapDetector to be generic over input feature types #34

Merged
merged 18 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 144 additions & 27 deletions pybedlite/overlap_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,17 @@
import itertools
from pathlib import Path
from typing import Dict
from typing import Generic
from typing import Hashable
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Protocol
from typing import Set
from typing import Type
from typing import TypeVar
from typing import Union

import attr

Expand All @@ -56,6 +61,52 @@
from pybedlite.bed_source import BedSource


class _Span(Protocol):
"""A span with a start and an end. 0-based open-ended."""

@property
def start(self) -> int:
"""A 0-based start position."""

@property
def end(self) -> int:
"""A 0-based open-ended position."""


class _GenomicSpanWithChrom(_Span, Hashable, Protocol):
"""A genomic feature where reference sequence is accessed with `chrom`."""

@property
def chrom(self) -> str:
"""A reference sequence name."""


class _GenomicSpanWithContig(_Span, Hashable, Protocol):
"""A genomic feature where reference sequence is accessed with `contig`."""

@property
def contig(self) -> str:
"""A reference sequence name."""


class _GenomicSpanWithRefName(_Span, Hashable, Protocol):
"""A genomic feature where reference sequence is accessed with `refname`."""

@property
def refname(self) -> str:
"""A reference sequence name."""


GenomicSpan = TypeVar(
"GenomicSpan",
bound=Union[_GenomicSpanWithChrom, _GenomicSpanWithContig, _GenomicSpanWithRefName],
)
"""
A genomic feature where the reference sequence name is accessed with any of the 3 most common
property names ("chrom", "contig", "refname").
"""


@attr.s(frozen=True, auto_attribs=True)
class Interval:
"""A region mapping to the genome that is 0-based and open-ended
Expand Down Expand Up @@ -126,36 +177,94 @@ def from_bedrecord(cls: Type["Interval"], record: BedRecord) -> "Interval":
)


class OverlapDetector(Iterable[Interval]):
_GenericGenomicSpan = TypeVar(
"_GenericGenomicSpan",
bound=Union[_GenomicSpanWithChrom, _GenomicSpanWithContig, _GenomicSpanWithRefName],
)
"""
A generic genomic feature where the reference sequence name is accessed with any of the 3 most
common property names ("chrom", "contig", "refname"). This type variable is used for describing the
generic type contained within the :class:`~pybedlite.overlap_detector.OverlapDetector`.
"""


class OverlapDetector(Generic[_GenericGenomicSpan], Iterable[_GenericGenomicSpan]):
"""Detects and returns overlaps between a set of genomic regions and another genomic region.

Since :class:`~pybedlite.overlap_detector.Interval` objects are used both to populate the
overlap detector and to query it, the coordinate system in use is also 0-based open-ended.
The overlap detector may contain any interval-like Python objects that have the following
properties:

* `chrom` or `contig` or `refname`: The reference sequence name
* `start`: A 0-based start position
* `end`: A 0-based exclusive end position

Interval-like Python objects may also contain strandedness information which will be used
for sorting them in :func:`~pybedlite.overlap_detector.OverlapDetector.get_overlaps` using
either of the following properties if they are present:

* `negative (bool)`: Whether or not the feature is negative stranded or not
* `strand (BedStrand)`: The BED strand of the feature
* `strand (str)`: The strand of the feature (`"-"` for negative)

The same interval may be added multiple times, but only a single instance will be returned
when querying for overlaps. Intervals with the same coordinates but different names are
treated as different intervals.
when querying for overlaps.

This detector is the most efficient when all intervals are added ahead of time.
"""

def __init__(self) -> None:
def __init__(self, intervals: Optional[Iterable[_GenericGenomicSpan]] = None) -> None:
clintval marked this conversation as resolved.
Show resolved Hide resolved
# A mapping from the contig/chromosome name to the associated interval tree
self._refname_to_tree: Dict[str, cr.cgranges] = {} # type: ignore
self._refname_to_indexed: Dict[str, bool] = {}
self._refname_to_intervals: Dict[str, List[Interval]] = {}
self._refname_to_intervals: Dict[str, List[_GenericGenomicSpan]] = {}
if intervals is not None:
self.add_all(intervals)
clintval marked this conversation as resolved.
Show resolved Hide resolved

def __iter__(self) -> Iterator[Interval]:
def __iter__(self) -> Iterator[_GenericGenomicSpan]:
"""Iterates over the intervals in the overlap detector."""
return itertools.chain(*self._refname_to_intervals.values())

def add(self, interval: Interval) -> None:
@staticmethod
def _reference_sequence_name(interval: GenomicSpan) -> str:
"""Return the reference name of a given interval."""
if isinstance(interval, Interval) or hasattr(interval, "refname"):
return interval.refname
elif isinstance(interval, BedRecord) or hasattr(interval, "chrom"):
return interval.chrom
elif hasattr(interval, "contig"):
return interval.contig
clintval marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(
f"Genomic span is missing a reference sequence name property: {interval}"
)

@staticmethod
def _is_negative(interval: GenomicSpan) -> bool:
"""Determine if this is a negative stranded interval or not."""
return (
(hasattr(interval, "negative") and interval.negative)
or (
hasattr(interval, "strand")
and isinstance(interval.strand, BedStrand)
and interval.strand is BedStrand.Negative
)
or (
hasattr(interval, "strand")
and isinstance(interval.strand, str)
and interval.strand == "-"
)
)

def add(self, interval: _GenericGenomicSpan) -> None:
"""Adds an interval to this detector.

Args:
interval: the interval to add to this detector
"""
refname = interval.refname
if not isinstance(interval, Hashable):
raise ValueError(f"Interval feature is not hashable but should be: {interval}")
clintval marked this conversation as resolved.
Show resolved Hide resolved

refname = self._reference_sequence_name(interval)
if refname not in self._refname_to_tree:
self._refname_to_tree[refname] = cr.cgranges() # type: ignore
self._refname_to_indexed[refname] = False
Expand All @@ -168,13 +277,13 @@ def add(self, interval: Interval) -> None:

# Add the interval to the tree
tree = self._refname_to_tree[refname]
tree.add(interval.refname, interval.start, interval.end, interval_idx)
tree.add(refname, interval.start, interval.end, interval_idx)

# Flag this tree as needing to be indexed after adding a new interval, but defer
# indexing
self._refname_to_indexed[refname] = False

def add_all(self, intervals: Iterable[Interval]) -> None:
def add_all(self, intervals: Iterable[_GenericGenomicSpan]) -> None:
"""Adds one or more intervals to this detector.

Args:
Expand All @@ -183,7 +292,7 @@ def add_all(self, intervals: Iterable[Interval]) -> None:
for interval in intervals:
self.add(interval)

def overlaps_any(self, interval: Interval) -> bool:
def overlaps_any(self, interval: GenomicSpan) -> bool:
msto marked this conversation as resolved.
Show resolved Hide resolved
"""Determines whether the given interval overlaps any interval in this detector.

Args:
Expand All @@ -193,20 +302,21 @@ def overlaps_any(self, interval: Interval) -> bool:
True if and only if the given interval overlaps with any interval in this
detector.
"""
tree = self._refname_to_tree.get(interval.refname)
refname = self._reference_sequence_name(interval)
tree = self._refname_to_tree.get(refname)
if tree is None:
return False
else:
if not self._refname_to_indexed[interval.refname]:
if not self._refname_to_indexed[refname]:
tree.index()
try:
next(iter(tree.overlap(interval.refname, interval.start, interval.end)))
next(iter(tree.overlap(refname, interval.start, interval.end)))
except StopIteration:
return False
else:
return True

def get_overlaps(self, interval: Interval) -> List[Interval]:
def get_overlaps(self, interval: GenomicSpan) -> List[_GenericGenomicSpan]:
clintval marked this conversation as resolved.
Show resolved Hide resolved
"""Returns any intervals in this detector that overlap the given interval.

Args:
Expand All @@ -216,23 +326,30 @@ def get_overlaps(self, interval: Interval) -> List[Interval]:
The list of intervals in this detector that overlap the given interval, or the empty
list if no overlaps exist. The intervals will be return in ascending genomic order.
"""
tree = self._refname_to_tree.get(interval.refname)
refname = self._reference_sequence_name(interval)
tree = self._refname_to_tree.get(refname)
if tree is None:
return []
else:
if not self._refname_to_indexed[interval.refname]:
if not self._refname_to_indexed[refname]:
tree.index()
ref_intervals: List[Interval] = self._refname_to_intervals[interval.refname]
ref_intervals: List[_GenericGenomicSpan] = self._refname_to_intervals[refname]
# NB: only return unique instances of intervals
intervals: Set[Interval] = {
intervals: Set[_GenericGenomicSpan] = {
ref_intervals[index]
for _, _, index in tree.overlap(interval.refname, interval.start, interval.end)
for _, _, index in tree.overlap(refname, interval.start, interval.end)
}
return sorted(
intervals, key=lambda intv: (intv.start, intv.end, intv.negative, intv.name)
intervals,
key=lambda intv: (
intv.start,
intv.end,
self._is_negative(intv),
self._reference_sequence_name(intv),
),
)

def get_enclosing_intervals(self, interval: Interval) -> List[Interval]:
def get_enclosing_intervals(self, interval: GenomicSpan) -> List[_GenericGenomicSpan]:
"""Returns the set of intervals in this detector that wholly enclose the query interval.
i.e. query.start >= target.start and query.end <= target.end.

Expand All @@ -245,7 +362,7 @@ def get_enclosing_intervals(self, interval: Interval) -> List[Interval]:
results = self.get_overlaps(interval)
return [i for i in results if interval.start >= i.start and interval.end <= i.end]

def get_enclosed(self, interval: Interval) -> List[Interval]:
def get_enclosed(self, interval: GenomicSpan) -> List[_GenericGenomicSpan]:
"""Returns the set of intervals in this detector that are enclosed by the query
interval. I.e. target.start >= query.start and target.end <= query.end.

Expand All @@ -267,9 +384,9 @@ def from_bed(cls, path: Path) -> "OverlapDetector":
Returns:
An overlap detector for the regions in the BED file.
"""
detector = OverlapDetector()
detector: OverlapDetector[BedRecord] = OverlapDetector()

for region in BedSource(path):
detector.add(Interval.from_bedrecord(region))
detector.add(region)
clintval marked this conversation as resolved.
Show resolved Hide resolved

return detector
71 changes: 67 additions & 4 deletions pybedlite/tests/test_overlap_detector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""Tests for :py:mod:`~pybedlite.overlap_detector`"""

from dataclasses import dataclass
from typing import List
from typing import Union

import pytest
from typing_extensions import TypeAlias

from pybedlite.bed_record import BedRecord
from pybedlite.bed_record import BedStrand
Expand All @@ -9,7 +14,7 @@


def run_test(targets: List[Interval], query: Interval, results: List[Interval]) -> None:
detector = OverlapDetector()
detector: OverlapDetector[Interval] = OverlapDetector()
# Use add_all() to covert itself and add()
detector.add_all(intervals=targets)
# Test overlaps_any()
Expand Down Expand Up @@ -113,7 +118,7 @@ def test_get_enclosing_intervals() -> None:
d = Interval("1", 15, 19)
e = Interval("1", 16, 20)

detector = OverlapDetector()
detector: OverlapDetector[Interval] = OverlapDetector()
detector.add_all([a, b, c, d, e])

assert detector.get_enclosing_intervals(Interval("1", 10, 100)) == [a]
Expand All @@ -128,7 +133,7 @@ def test_get_enclosed() -> None:
c = Interval("1", 18, 19)
d = Interval("1", 50, 99)

detector = OverlapDetector()
detector: OverlapDetector[Interval] = OverlapDetector()
detector.add_all([a, b, c, d])

assert detector.get_enclosed(Interval("1", 1, 250)) == [a, b, c, d]
Expand All @@ -145,7 +150,7 @@ def test_iterable() -> None:
d = Interval("1", 15, 19)
e = Interval("1", 16, 20)

detector = OverlapDetector()
detector: OverlapDetector[Interval] = OverlapDetector()
detector.add_all([a])
assert list(detector) == [a]
detector.add_all([a, b, c, d, e])
Expand Down Expand Up @@ -188,3 +193,61 @@ def test_construction_from_interval(bed_records: List[BedRecord]) -> None:
assert new_record.strand is BedStrand.Positive
else:
assert new_record.strand is record.strand


def test_arbitrary_interval_types() -> None:
"""
Test that an overlap detector can receive different interval-like objects and query them too.
"""

@dataclass(eq=True, frozen=True)
class ChromFeature:
chrom: str
start: int
end: int

@dataclass(eq=True, frozen=True)
class ContigFeature:
contig: str
start: int
end: int

@dataclass(eq=True, frozen=True)
class RefnameFeature:
refname: str
start: int
end: int

# Create minimal features of all supported structural types
chrom_feature = ChromFeature(chrom="chr1", start=0, end=30)
contig_feature = ContigFeature(contig="chr1", start=10, end=40)
refname_feature = RefnameFeature(refname="chr1", start=20, end=50)

# Setup an overlap detector to hold all the features we care about
AllKinds: TypeAlias = Union[ChromFeature, ContigFeature, RefnameFeature]
features: List[AllKinds] = [chrom_feature, contig_feature, refname_feature]
detector: OverlapDetector[AllKinds] = OverlapDetector(features)

# Query the overlap detector with yet another type
assert detector.get_overlaps(Interval("chr1", 0, 1)) == [chrom_feature]
assert detector.get_overlaps(Interval("chr1", 25, 26)) == [
chrom_feature,
contig_feature,
refname_feature,
]
assert detector.get_overlaps(Interval("chr1", 45, 46)) == [refname_feature]


def test_the_overlap_detector_wont_accept_a_non_hashable_feature() -> None:
"""
Test that an overlap detector will not accept a non-hashable feature.
"""

@dataclass # A dataclass missing both `eq` and `frozen` does not implement __hash__.
class ChromFeature:
chrom: str
start: int
end: int

with pytest.raises(ValueError):
OverlapDetector([ChromFeature(chrom="chr1", start=0, end=30)])
Loading