Skip to content

Commit

Permalink
Add a few type annotations to implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
srittau committed Mar 13, 2024
1 parent 2740a24 commit 222a9a2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
23 changes: 14 additions & 9 deletions asserts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
"""

from __future__ import annotations

import re
import sys
from datetime import datetime, timedelta
from json import loads as json_loads
from typing import Set
from warnings import catch_warnings
from typing import Any, Callable, Set
from warnings import WarningMessage, catch_warnings


def fail(msg=None):
Expand Down Expand Up @@ -864,7 +866,7 @@ def assert_datetime_about_now_utc(actual, msg_fmt="{msg}"):
fail(msg_fmt.format(msg=msg, actual=actual, now=now))


class AssertRaisesContext(object):
class AssertRaisesContext:
"""A context manager to test for exceptions with certain properties.
When the context is left and no exception has been raised, an
Expand Down Expand Up @@ -906,7 +908,7 @@ def __init__(self, exception, msg_fmt="{msg}"):
self._exc_type = exception
self._exc_val = None
self._exception_name = getattr(exception, "__name__", str(exception))
self._tests = []
self._tests: list[Callable[[Any], object]] = []

def __enter__(self):
return self
Expand All @@ -929,7 +931,7 @@ def format_message(self, default_msg):
exc_name=self._exception_name,
)

def add_test(self, cb):
def add_test(self, cb: Callable[[Any], object]) -> None:
"""Add a test callback.
This callback is called after determining that the right exception
Expand Down Expand Up @@ -1188,16 +1190,19 @@ class AssertWarnsContext(object):
def __init__(self, warning_class, msg_fmt="{msg}"):
self._warning_class = warning_class
self._msg_fmt = msg_fmt
self._warning_context = None
self._warning_context: catch_warnings[list[WarningMessage]] | None = (
None
)
self._warnings = []
self._tests = []
self._tests: list[Callable[[Warning], bool]] = []

def __enter__(self):
self._warning_context = catch_warnings(record=True)
self._warnings = self._warning_context.__enter__()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
assert self._warning_context is not None
self._warning_context.__exit__(exc_type, exc_val, exc_tb)
if not any(self._is_expected_warning(w) for w in self._warnings):
fail(self.format_message())
Expand All @@ -1210,12 +1215,12 @@ def format_message(self):
exc_name=self._warning_class.__name__,
)

def _is_expected_warning(self, warning):
def _is_expected_warning(self, warning) -> bool:
if not issubclass(warning.category, self._warning_class):
return False
return all(test(warning) for test in self._tests)

def add_test(self, cb):
def add_test(self, cb: Callable[[Warning], bool]) -> None:
"""Add a test callback.
This callback is called after determining that the right warning
Expand Down
8 changes: 3 additions & 5 deletions test_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,8 +1269,9 @@ def extra_test(warning):
def test_assert_warns__add_test_not_called(self):
called = Box(False)

def extra_test(_):
def extra_test(_: Warning) -> bool:
called.value = True
return False

with assert_raises(AssertionError):
with assert_warns(UserWarning) as context:
Expand Down Expand Up @@ -1342,10 +1343,7 @@ def test_assert_warns_regex__not_issued__default_message(self):
pass

def test_assert_warns_regex__not_issued__custom_message(self):
expected = (
"no ImportWarning matching 'abc' issued;ImportWarning;"
"ImportWarning;abc"
)
expected = "no ImportWarning matching 'abc' issued;ImportWarning;ImportWarning;abc"
with _assert_raises_assertion(expected):
msg_fmt = "{msg};{exc_type.__name__};{exc_name};{pattern}"
with assert_warns_regex(ImportWarning, r"abc", msg_fmt=msg_fmt):
Expand Down

0 comments on commit 222a9a2

Please sign in to comment.