Skip to content

Commit

Permalink
Add types to run command
Browse files Browse the repository at this point in the history
  • Loading branch information
ssbarnea committed Dec 9, 2024
1 parent 914ec1a commit 3670710
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 5 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ disable = [
# Disabled on purpose:
"line-too-long", # covered by black
"protected-access", # covered by ruff SLF001
"redefined-builtin", # covered by ruff
"too-many-branches", # covered by ruff C901
"wrong-import-order", # covered by ruff
# TODO(ssbarnea): remove temporary skips adding during initial adoption:
Expand Down Expand Up @@ -195,6 +196,7 @@ exclude = [
".tox",
"build",
"venv",
"src/subprocess_tee/_version.py"
"src/subprocess_tee/_version.py",
"src/subprocess_tee/_types.py"
]
paths = ["src", "test"]
40 changes: 36 additions & 4 deletions src/subprocess_tee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""tee like run implementation."""

# cspell: ignore popenargs

from __future__ import annotations

import asyncio
Expand All @@ -21,7 +23,9 @@
__all__ = ["CompletedProcess", "__version__", "run"]

if TYPE_CHECKING:
CompletedProcess = subprocess.CompletedProcess[Any] # pylint: disable=E1136
from subprocess_tee._types import SequenceNotStr

CompletedProcess = subprocess.CompletedProcess[Any]
from collections.abc import Callable
else:
CompletedProcess = subprocess.CompletedProcess
Expand All @@ -39,7 +43,7 @@ async def _read_stream(stream: StreamReader, callback: Callable[..., Any]) -> No


async def _stream_subprocess( # noqa: C901
args: str | list[str],
args: str | tuple[str, ...],
**kwargs: Any,
) -> CompletedProcess:
platform_settings: dict[str, Any] = {}
Expand Down Expand Up @@ -136,7 +140,16 @@ def tee_func(line: bytes, sink: list[str], pipe: Any | None) -> None:
)


def run(args: str | list[str], **kwargs: Any) -> CompletedProcess:
# signature is based on stdlib
# subprocess.run()
def run(
*popenargs: str | SequenceNotStr[str],
input: bytes | str | None = None, # noqa: A002
capture_output: bool = False,
timeout: int | None = None,
check: bool = False,
**kwargs: Any,
) -> CompletedProcess:
"""Drop-in replacement for subprocess.run that behaves like tee.
Extra arguments added by our version:
Expand All @@ -148,12 +161,31 @@ def run(args: str | list[str], **kwargs: Any) -> CompletedProcess:
Raises:
CalledProcessError: ...
TypeError: ...
"""
# run was called with a list instead of a single item but asyncio
# create_subprocess_shell requires command as a single string, so
# we need to convert it to string
cmd = args if isinstance(args, str) else join(args)
# breakpoint()
# if len(popenargs) == 1 and isinstance(popenargs[0], list):
# cmd = join(popenargs[0])
# else:
args: str | tuple[str, ...]
if len(popenargs) == 0:
args = ()
else:
if not isinstance(popenargs, str): # make mypy/pyright happy
raise TypeError(popenargs)
_ = popenargs[0]
if not isinstance(_, str): # make mypy/pyright happy
raise TypeError(_)
args = _
cmd = popenargs if isinstance(popenargs, str) else join(args)
kwargs["check"] = check
kwargs["input"] = input
kwargs["timeout"] = timeout
kwargs["capture_output"] = capture_output

check = kwargs.get("check", False)

Expand Down
25 changes: 25 additions & 0 deletions src/subprocess_tee/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Internally used types."""

# Source from https://github.com/python/typing/issues/256#issuecomment-1442633430
from collections.abc import Iterator, Sequence
from typing import Any, Protocol, SupportsIndex, TypeVar, overload

_T_co = TypeVar("_T_co", covariant=True)


class SequenceNotStr(Protocol[_T_co]):
"""Lists of strings which are not strings themselves."""

@overload
def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
@overload
def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
def __contains__(self, value: object, /) -> bool: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T_co]: ...
def index( # pylint: disable=C0116
self, value: Any, start: int = 0, stop: int = ..., /
) -> int: ...
def count(self, value: Any, /) -> int: ... # pylint: disable=C0116

def __reversed__(self) -> Iterator[_T_co]: ...

0 comments on commit 3670710

Please sign in to comment.