diff --git a/.github/workflows/test-torchfix.yml b/.github/workflows/test-torchfix.yml index 9a61e01..bb007d1 100644 --- a/.github/workflows/test-torchfix.yml +++ b/.github/workflows/test-torchfix.yml @@ -8,11 +8,15 @@ jobs: test-torchfix: strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest, macos-latest-large, windows-latest] runs-on: ${{ matrix.os }} steps: - name: Checkout uses: actions/checkout@v4 + - name: Show CPU architecture + if: matrix.os == 'macos-latest-large' || matrix.os == 'macos-latest' + run: | + uname -m - uses: actions/setup-python@v5 with: python-version: '3.10' @@ -23,9 +27,12 @@ jobs: - name: Install TorchFix run: | pip3 install ".[dev]" + - name: Run torchfix CLI + run: | + torchfix --help - name: Run pytest run: | - pytest tests + pytest -vv tests - name: Run flake8 run: | flake8 diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 56fd05c..01ee556 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -1,3 +1,5 @@ +import os +import subprocess from pathlib import Path from torchfix.torchfix import ( TorchChecker, @@ -92,3 +94,37 @@ def test_errorcodes_distinct(): def test_parse_error_code_str(case, expected): assert process_error_code_str(case) == expected + + +def test_stderr_suppression(tmp_path): + data = f"import torchvision.datasets as datasets{os.linesep}" + data_path = tmp_path / "fixable.py" + data_path.write_text(data) + result = subprocess.run( + ["torchfix", "--select", "TOR203", "--fix", str(data_path)], + stderr=subprocess.PIPE, + text=True, + check=False, + ) + assert ( + result.stderr == "Finished checking 1 files.\n" + "Transformed 1 files successfully.\n" + ) + + data = f"import torchvision.datasets as datasets{os.linesep}" + data_path = tmp_path / "fixable.py" + data_path.write_text(data) + result = subprocess.run( + ["torchfix", "--select", "TOR203", "--show-stderr", "--fix", str(data_path)], + stderr=subprocess.PIPE, + text=True, + check=False, + ) + expected = result.stderr.replace("\\\\", "\\") + assert ( + expected == f"Executing codemod...\n" + f"Failed to determine module name for {data_path}: '{data_path}' is not in the " + f"subpath of '' OR one path is relative and the other is absolute.\n" + f"Finished checking 1 files.\n" + f"Transformed 1 files successfully.\n" + ) diff --git a/torchfix/__main__.py b/torchfix/__main__.py index eb17658..5400f1b 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -2,7 +2,6 @@ import libcst.codemod as codemod import contextlib -import ctypes import sys import io @@ -17,29 +16,6 @@ from .common import CYAN, ENDC -# Should get rid of this code eventually. -@contextlib.contextmanager -def StderrSilencer(redirect: bool = True): - if not redirect: - yield - elif sys.platform != "darwin": - with contextlib.redirect_stderr(io.StringIO()): - yield - else: - # redirect_stderr does not work for some reason - # Workaround it by using good old dup2 to redirect - # stderr to /dev/null - libc = ctypes.CDLL("libc.dylib") - orig_stderr = libc.dup(2) - with open("/dev/null", "w") as f: - libc.dup2(f.fileno(), 2) - try: - yield - finally: - libc.dup2(orig_stderr, 2) - libc.close(orig_stderr) - - def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() @@ -102,7 +78,12 @@ def main() -> None: command_instance = TorchCodemod(codemod.CodemodContext(), config) DIFF_CONTEXT = 5 try: - with StderrSilencer(not args.show_stderr): + supress_stderr = ( + contextlib.redirect_stderr(io.StringIO()) + if not args.show_stderr + else contextlib.nullcontext() + ) + with supress_stderr: result = codemod.parallel_exec_transform_with_prettyprint( command_instance, torch_files,