Skip to content

Commit

Permalink
refactor(wrap_stdio): remake classes
Browse files Browse the repository at this point in the history
  • Loading branch information
saygox committed Nov 22, 2021
1 parent a5860a8 commit 3a15639
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 30 deletions.
36 changes: 22 additions & 14 deletions commitizen/wrap_stdio_linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,50 @@

if sys.platform == "linux": # pragma: no cover
import os
from io import IOBase

class WrapStdioLinux:
def __init__(self, stdx: IOBase):
self._fileno = stdx.fileno()
if self._fileno == 0:
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
tty = open(fd, "wb+", buffering=0)
else:
tty = open("/dev/tty", "w") # type: ignore

# from io import IOBase

class WrapStdinLinux:
def __init__(self):
fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
tty = open(fd, "wb+", buffering=0)
self.tty = tty

def __getattr__(self, key):
if key == "encoding" and self._fileno == 0:
if key == "encoding":
return "UTF-8"
return getattr(self.tty, key)

def __del__(self):
self.tty.close()

class WrapStdoutLinux:
def __init__(self):
tty = open("/dev/tty", "w")
self.tty = tty

def __getattr__(self, key):
return getattr(self.tty, key)

def __del__(self):
self.tty.close()

backup_stdin = None
backup_stdout = None
backup_stderr = None

def _wrap_stdio():
global backup_stdin
backup_stdin = sys.stdin
sys.stdin = WrapStdioLinux(sys.stdin)
sys.stdin = WrapStdinLinux()

global backup_stdout
backup_stdout = sys.stdout
sys.stdout = WrapStdioLinux(sys.stdout)
sys.stdout = WrapStdoutLinux()

global backup_stderr
backup_stderr = sys.stderr
sys.stderr = WrapStdioLinux(sys.stderr)
sys.stderr = WrapStdoutLinux()

def _unwrap_stdio():
global backup_stdin
Expand Down
12 changes: 6 additions & 6 deletions commitizen/wrap_stdio_unix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,12 @@
import selectors
from asyncio import (
DefaultEventLoopPolicy,
SelectorEventLoop,
get_event_loop_policy,
set_event_loop_policy,
)
from io import IOBase

class CZEventLoopPolicy(DefaultEventLoopPolicy): # pragma: no cover
def get_event_loop(self):
self.set_event_loop(self._loop_factory(selectors.SelectSelector()))
return self._local._loop

class WrapStdioUnix:
def __init__(self, stdx: IOBase):
self._fileno = stdx.fileno()
Expand All @@ -33,6 +29,7 @@ def __getattr__(self, key):
def __del__(self):
self.tty.close()

# backup_event_loop = None
backup_event_loop_policy = None
backup_stdin = None
backup_stdout = None
Expand All @@ -41,7 +38,10 @@ def __del__(self):
def _wrap_stdio():
global backup_event_loop_policy
backup_event_loop_policy = get_event_loop_policy()
set_event_loop_policy(CZEventLoopPolicy())

event_loop = DefaultEventLoopPolicy()
event_loop.set_event_loop(SelectorEventLoop(selectors.SelectSelector()))
set_event_loop_policy(event_loop)

global backup_stdin
backup_stdin = sys.stdin
Expand Down
30 changes: 20 additions & 10 deletions tests/test_wrap_stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,52 @@ def test_warp_stdio_exists():
if sys.platform == "win32": # pragma: no cover
pass
elif sys.platform == "linux":
from commitizen.wrap_stdio_linux import WrapStdioLinux
from commitizen.wrap_stdio_linux import WrapStdinLinux, WrapStdoutLinux

def test_wrap_stdio_linux(mocker):
def test_wrap_stdin_linux(mocker):

tmp_stdin = sys.stdin
tmp_stdout = sys.stdout
tmp_stderr = sys.stderr

mocker.patch("os.open")
readerwriter_mock = mocker.mock_open(read_data="data")
mocker.patch("builtins.open", readerwriter_mock, create=True)

mocker.patch.object(sys.stdin, "fileno", return_value=0)
mocker.patch.object(sys.stdout, "fileno", return_value=1)
mocker.patch.object(sys.stdout, "fileno", return_value=2)

wrap_stdio.wrap_stdio()

assert sys.stdin != tmp_stdin
assert isinstance(sys.stdin, WrapStdioLinux)
assert isinstance(sys.stdin, WrapStdinLinux)
assert sys.stdin.encoding == "UTF-8"
assert sys.stdin.read() == "data"

wrap_stdio.unwrap_stdio()

assert sys.stdin == tmp_stdin

def test_wrap_stdout_linux(mocker):

tmp_stdout = sys.stdout
tmp_stderr = sys.stderr

mocker.patch("os.open")
readerwriter_mock = mocker.mock_open(read_data="data")
mocker.patch("builtins.open", readerwriter_mock, create=True)

wrap_stdio.wrap_stdio()

assert sys.stdout != tmp_stdout
assert isinstance(sys.stdout, WrapStdioLinux)
assert isinstance(sys.stdout, WrapStdoutLinux)
sys.stdout.write("stdout")
readerwriter_mock().write.assert_called_with("stdout")

assert sys.stderr != tmp_stderr
assert isinstance(sys.stderr, WrapStdioLinux)
assert isinstance(sys.stderr, WrapStdoutLinux)
sys.stdout.write("stderr")
readerwriter_mock().write.assert_called_with("stderr")

wrap_stdio.unwrap_stdio()

assert sys.stdin == tmp_stdin
assert sys.stdout == tmp_stdout
assert sys.stderr == tmp_stderr

Expand Down

0 comments on commit 3a15639

Please sign in to comment.