From ba47eba2325e34310164d284b5de1bd71ec1071e Mon Sep 17 00:00:00 2001 From: Nils Homer Date: Tue, 27 Jun 2023 15:31:44 -0700 Subject: [PATCH] feat: make OverlapDetector iterable (#16) --- pybedlite/overlap_detector.py | 13 ++++++++++--- pybedlite/tests/test_overlap_detector.py | 14 ++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/pybedlite/overlap_detector.py b/pybedlite/overlap_detector.py index 7bc1ce9..0a96d06 100644 --- a/pybedlite/overlap_detector.py +++ b/pybedlite/overlap_detector.py @@ -38,15 +38,18 @@ a set of genomic regions and another genomic region """ -import attr -import cgranges as cr +import itertools from pathlib import Path from typing import Dict from typing import Iterable +from typing import Iterator from typing import List from typing import Optional from typing import Set +import attr +import cgranges as cr + from pybedlite.bed_record import BedStrand from pybedlite.bed_source import BedSource @@ -98,7 +101,7 @@ def length(self) -> int: return self.end - self.start -class OverlapDetector: +class OverlapDetector(Iterable[Interval]): """Detects and returns overlaps between a set of genomic regions and another genomic region. Since :class:`~samwell.overlap_detector.Interval` objects are used both to populate the @@ -117,6 +120,10 @@ def __init__(self) -> None: self._refname_to_indexed: Dict[str, bool] = {} self._refname_to_intervals: Dict[str, List[Interval]] = {} + def __iter__(self) -> Iterator[Interval]: + """Iterates over the intervals in the overlap detector.""" + return itertools.chain(*self._refname_to_intervals.values()) + def add(self, interval: Interval) -> None: """Adds an interval to this detector. diff --git a/pybedlite/tests/test_overlap_detector.py b/pybedlite/tests/test_overlap_detector.py index 93a117a..bf9b8a9 100644 --- a/pybedlite/tests/test_overlap_detector.py +++ b/pybedlite/tests/test_overlap_detector.py @@ -134,3 +134,17 @@ def test_get_enclosed() -> None: assert detector.get_enclosed(Interval("1", 16, 20)) == [c] assert detector.get_enclosed(Interval("1", 15, 19)) == [c] assert detector.get_enclosed(Interval("1", 10, 99)) == [b, c, d] + + +def test_iterable() -> None: + a = Interval("1", 1, 250) + b = Interval("1", 5, 30) + c = Interval("1", 10, 99) + d = Interval("1", 15, 19) + e = Interval("1", 16, 20) + + detector = OverlapDetector() + detector.add_all([a]) + assert list(detector) == [a] + detector.add_all([a, b, c, d, e]) + assert list(detector) == [a, a, b, c, d, e]