Skip to content

Commit

Permalink
feat: add a type-savvy .from_path() to BedWriter and BedReader (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
clintval authored May 17, 2024
1 parent 2ceac80 commit a798e10
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 38 deletions.
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,22 @@ pip install bedspec
### Writing

```python
from bedspec import BedWriter, Bed3
from bedspec import Bed3
from bedspec import BedWriter

bed = Bed3("chr1", start=2, end=8)

with BedWriter(open("test.bed", "w")) as writer:
with BedWriter[Bed3].from_path("test.bed") as writer:
writer.write(bed)
```

### Reading

```python
from bedspec import BedReader, Bed3
from bedspec import Bed3
from bedspec import BedReader

with BedReader[Bed3](open("test.bed")) as reader:
with BedReader[Bed3].from_path("test.bed") as reader:
for bed in reader:
print(bed)
```
Expand Down
1 change: 1 addition & 0 deletions bedspec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ._bedspec import BedStrand
from ._bedspec import BedType
from ._bedspec import BedWriter
from ._bedspec import Locatable
from ._bedspec import PairBed
from ._bedspec import PointBed
from ._bedspec import SimpleBed
Expand Down
156 changes: 126 additions & 30 deletions bedspec/_bedspec.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,143 @@
import dataclasses
import inspect
import io
import typing
from abc import ABC
from abc import abstractmethod
from dataclasses import asdict as as_dict
from dataclasses import dataclass
from dataclasses import fields
from enum import StrEnum
from enum import unique
from functools import update_wrapper
from pathlib import Path
from types import FrameType
from types import TracebackType
from types import UnionType
from typing import Any
from typing import Callable
from typing import ClassVar
from typing import ContextManager
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import Protocol
from typing import Type
from typing import TypeVar
from typing import Union
from typing import _BaseGenericAlias # type: ignore[attr-defined]
from typing import _GenericAlias # type: ignore[attr-defined]
from typing import cast
from typing import get_args
from typing import get_origin
from typing import get_type_hints
from typing import runtime_checkable

COMMENT_PREFIXES: set[str] = {"#", "browser", "track"}
"""The set of BED comment prefixes supported by this implementation."""

MISSING_FIELD: str = "."
"""The string used to indicate a missing field in a BED record."""

BED_EXTENSION: str = ".bed"
"""The specification defined file extension for BED files."""

BEDPE_EXTENSION: str = ".bedpe"
"""The specification defined file extension for BedPE files."""


def is_union(annotation: Type) -> bool:
"""Test if we have a union type annotation or not."""
return get_origin(annotation) in {Union, UnionType}


def is_optional(annotation: Type) -> bool:
    """Report whether a type annotation is optional, i.e. a union that includes None.

    Args:
        annotation: The type annotation to inspect.

    Returns:
        True when the annotation is a union and `NoneType` is one of its members, else False.
    """
    if not is_union(annotation):
        return False
    return type(None) in get_args(annotation)


def singular_non_optional_type(annotation: Type) -> Type:
    """Strip the None member from an optional annotation that wraps exactly one other type.

    Args:
        annotation: The type annotation to unwrap.

    Returns:
        The annotation itself when it is not optional, otherwise its single non-None member.

    Raises:
        TypeError: If the optional annotation does not have exactly one non-None member.
    """
    if not is_optional(annotation):
        return annotation

    none_type = type(None)
    remaining: list[Type] = [member for member in get_args(annotation) if member is not none_type]
    if len(remaining) != 1:
        raise TypeError(f"Complex non-optional types are not supported! Found: {remaining}")
    return remaining[0]


class MethodType:
    """A minimal bound-method object: a function paired with the object passed as its first argument.

    Both halves are stored as ordinary instance attributes (`__func__` and `__self__`), so the
    bound object can be replaced after construction.
    """

    def __init__(self, func: Callable, obj: object) -> None:
        """Remember the function to call and the object to bind to it."""
        self.__func__ = func
        self.__self__ = obj

    def __call__(self, *args: object, **kwargs: object) -> object:
        """Call the stored function with the bound object prepended to the given arguments."""
        return self.__func__(self.__self__, *args, **kwargs)


class classmethod_generic:
    """A classmethod-style descriptor whose bound result is tagged as a generic classmethod.

    Attribute access returns a `MethodType` bound to the owning class and carrying a
    `_generic_classmethod` marker attribute set to True.
    """

    def __init__(self, f: Callable) -> None:
        """Wrap the function and copy its metadata (name, docstring, ...) onto this descriptor."""
        self.f = f
        update_wrapper(self, f)

    def __get__(self, obj: object, cls: object | None = None) -> Callable:
        """Bind the wrapped function to the owning class.

        Args:
            obj: The instance the attribute was looked up on, or None for class-level access.
            cls: The owning class; when None it is taken from the type of `obj`.

        Returns:
            A `MethodType` bound to the class and marked with `_generic_classmethod = True`.
        """
        owner = type(obj) if cls is None else cls
        bound = MethodType(self.f, owner)
        bound._generic_classmethod = True  # type: ignore[attr-defined]
        return bound


def __getattr__(self: object, name: str | None = None) -> object:
if hasattr(obj := orig_getattr(self, name), "_generic_classmethod"):
obj.__self__ = self
return obj


orig_getattr = _BaseGenericAlias.__getattr__
_BaseGenericAlias.__getattr__ = __getattr__


@unique
class BedStrand(StrEnum):
"""Valid BED strands for forward, reverse, and unknown directions."""

POSITIVE = "+"
NEGATIVE = "-"
UNKNOWN = MISSING_FIELD

def opposite(self) -> "BedStrand":
"""Return the opposite strand."""
match self:
case BedStrand.POSITIVE:
return BedStrand.NEGATIVE
case BedStrand.NEGATIVE:
return BedStrand.POSITIVE


@dataclass
class BedColor:
    """The color of a BED record in red, green, and blue values.

    The constructor is the one generated by `@dataclass` and takes `r`, `g`, and `b` in that
    order, positionally or by keyword.
    """

    # Red, green, and blue components, in constructor order.
    r: int
    g: int
    b: int

    def __str__(self) -> str:
        """Return the three components joined by commas, e.g. `"255,0,0"`."""
        return f"{self.r},{self.g},{self.b}"


@runtime_checkable
class DataclassProtocol(Protocol):
    """A structural type satisfied by any object that exposes `__dataclass_fields__`.

    Because the protocol is runtime-checkable, `isinstance(obj, DataclassProtocol)` can be used
    to test for the attribute's presence.
    """

    # The only member of the protocol: a class-level mapping keyed by string.
    __dataclass_fields__: ClassVar[dict[str, Any]]


@runtime_checkable
class Locatable(Protocol):
"""A protocol for 0-based half-open objects located on a reference sequence."""

Expand All @@ -66,16 +146,17 @@ class Locatable(Protocol):
end: int


@runtime_checkable
class Stranded(Protocol):
    """A protocol for stranded BED types.

    Any object with a `strand` attribute satisfies this protocol; being runtime-checkable, it can
    be used with `isinstance`.
    """

    # Declared once: an earlier `strand: BedStrand` annotation was immediately overridden by this
    # one and had no effect. None presumably stands for a strand that is absent from the record.
    strand: BedStrand | None


class BedType(ABC, DataclassProtocol):
"""An abstract base class for all types of BED records."""

def __new__(cls, *args: Any, **kwargs: Any) -> "BedType":
def __new__(cls, *args: object, **kwargs: object) -> "BedType":
if not dataclasses.is_dataclass(cls):
raise TypeError("You must mark custom BED records with @dataclass!")
return cast("BedType", object.__new__(cls))
Expand All @@ -84,7 +165,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "BedType":
def decode(cls, line: str) -> "BedType":
"""Decode a line of text into a BED record."""
row: list[str] = line.strip().split()
coerced: dict[str, Any] = {}
coerced: dict[str, object] = {}

try:
zipped = list(zip(fields(cls), row, strict=True))
Expand All @@ -94,9 +175,14 @@ def decode(cls, line: str) -> "BedType":
f" '{' '.join(row)}'"
) from None

hints: dict[str, Type] = get_type_hints(cls)

for field, value in zipped:
try:
coerced[field.name] = field.type(value)
if is_optional(hints[field.name]) and value == MISSING_FIELD:
coerced[field.name] = None
else:
coerced[field.name] = singular_non_optional_type(field.type)(value)
except ValueError:
raise TypeError(
f"Tried to build the BED field '{field.name}' (of type '{field.type.__name__}')"
Expand All @@ -117,6 +203,7 @@ class PointBed(BedType, ABC):
contig: str
start: int

@property
def length(self) -> int:
"""The length of this record."""
return 1
Expand All @@ -138,6 +225,7 @@ def __post_init__(self) -> None:
if self.start >= self.end or self.start < 0:
raise ValueError("start must be greater than 0 and less than end!")

@property
def length(self) -> int:
"""The length of this record."""
return self.end - self.start
Expand Down Expand Up @@ -204,7 +292,7 @@ class Bed4(SimpleBed):
contig: str
start: int
end: int
name: str
name: str | None


@dataclass
Expand All @@ -214,8 +302,8 @@ class Bed5(SimpleBed):
contig: str
start: int
end: int
name: str
score: int
name: str | None
score: int | None


@dataclass
Expand All @@ -225,9 +313,9 @@ class Bed6(SimpleBed, Stranded):
contig: str
start: int
end: int
name: str
score: int
strand: BedStrand
name: str | None
score: int | None
strand: BedStrand | None


# @dataclass
Expand Down Expand Up @@ -260,10 +348,10 @@ class BedPE(PairBed):
contig2: str
start2: int
end2: int
name: str
score: int
strand1: BedStrand
strand2: BedStrand
name: str | None
score: int | None
strand1: BedStrand | None
strand2: BedStrand | None

@property
def bed1(self) -> Bed6:
Expand Down Expand Up @@ -306,9 +394,9 @@ class BedWriter(Generic[BedKind], ContextManager):

bed_kind: type[BedKind] | None

def __class_getitem__(cls, key: object) -> type:
    """Wrap all objects of this class to become generic aliases.

    Args:
        key: The type argument(s) written between the square brackets.

    Returns:
        A `typing._GenericAlias` whose origin is this class and whose arguments are `key`.
    """
    # Single definition: the header was previously stacked twice and followed by two returns,
    # the second of which could never run.
    return _GenericAlias(cls, key)  # type: ignore[no-any-return]

def __new__(cls, handle: io.TextIOWrapper) -> "BedWriter[BedKind]":
"""Bind the kind of BED type to this class for later introspection."""
Expand Down Expand Up @@ -337,6 +425,13 @@ def __exit__(
self.close()
return super().__exit__(__exc_type, __exc_value, __traceback)

@classmethod_generic
def from_path(cls, path: Path | str) -> "BedWriter[BedKind]":
    """Open a BED writer from a file path.

    Args:
        path: The file to open for writing (text mode `"w"`).

    Returns:
        A writer over the opened handle. Its `bed_kind` is the first type argument when called
        through a subscripted alias such as `BedWriter[Bed3]`, and None when called on the
        unsubscripted class.
    """
    writer = cls(handle=Path(path).open("w"))  # type: ignore[operator]
    # Only a subscripted alias carries `__args__`; the bare class has no such attribute, so a
    # direct `cls.__args__` lookup would raise AttributeError instead of yielding None.
    type_args = getattr(cls, "__args__", ())
    writer.bed_kind = type_args[0] if type_args else None
    return cast("BedWriter[BedKind]", writer)

def close(self) -> None:
"""Close the underlying IO handle."""
self._handle.close()
Expand Down Expand Up @@ -390,9 +485,9 @@ class BedReader(Generic[BedKind], ContextManager, Iterable[BedKind]):

bed_kind: type[BedKind] | None

def __class_getitem__(cls, key: object) -> type:
    """Wrap all objects of this class to become generic aliases.

    Args:
        key: The type argument(s) written between the square brackets.

    Returns:
        A `typing._GenericAlias` whose origin is this class and whose arguments are `key`.
    """
    # Single definition: the header was previously stacked twice and followed by two returns,
    # the second of which could never run.
    return _GenericAlias(cls, key)  # type: ignore[no-any-return]

def __new__(cls, handle: io.TextIOWrapper) -> "BedReader[BedKind]":
"""Bind the kind of BED type to this class for later introspection."""
Expand All @@ -413,6 +508,7 @@ def __enter__(self) -> "BedReader[BedKind]":

def __iter__(self) -> Iterator[BedKind]:
"""Iterate through the BED records of this IO handle."""
# TODO: Implement __next__ and type this class as an iterator.
if self.bed_kind is None:
raise NotImplementedError("Untyped reading is not yet supported!")
for line in self._handle:
Expand All @@ -432,12 +528,12 @@ def __exit__(
self.close()
return super().__exit__(__exc_type, __exc_value, __traceback)

@classmethod_generic
def from_path(
    cls, path: Path | str, bed_kind: type[BedKind] | None = None
) -> "BedReader[BedKind]":
    """Open a BED reader from a file path.

    Args:
        path: The file to open for reading (default text mode).
        bed_kind: The BED record type to bind to the reader. When omitted, it is taken from the
            first type argument of the subscripted alias this was called through (for example
            `BedReader[Bed3].from_path(...)`).

    Returns:
        A reader over the opened handle. Its `bed_kind` is None when no kind was passed and the
        method was called on the unsubscripted class.
    """
    reader = cls(handle=Path(path).open())  # type: ignore[operator]
    if bed_kind is None:
        # Only a subscripted alias carries `__args__`; the bare class has no such attribute, so
        # a direct `cls.__args__` lookup would raise AttributeError instead of yielding None.
        type_args = getattr(cls, "__args__", ())
        bed_kind = type_args[0] if type_args else None
    reader.bed_kind = bed_kind
    return cast("BedReader[BedKind]", reader)

def close(self) -> None:
"""Close the underlying IO handle."""
Expand Down
Loading

0 comments on commit a798e10

Please sign in to comment.