From baa18d9bde18ea01b6cd9a578d96efca537dc425 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sat, 30 Dec 2023 17:04:28 +0700 Subject: [PATCH 1/2] Fix pyproject.toml; Added a test --- pyproject.toml | 5 +++++ tests/test_types.py | 14 +++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 85bf3fd..791addc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,11 @@ contextvars = {version = "*", python = "~3.6"} [tool.poetry.dev-dependencies] typing_extensions = "*" +# The following are used for benchmarking - +pytest-benchmark = "*" +beartype = "*" +plum-dispatch = "*" +multipledispatch = "*" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/test_types.py b/tests/test_types.py index e23f5ba..2bc6f7b 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -5,7 +5,7 @@ import collections.abc as cabc from runtype.base_types import DataType, ContainerType, PhantomType -from runtype.pytypes import type_caster, List, Dict, Int, Any, Constraint, String, Tuple, Iter, Literal +from runtype.pytypes import type_caster, List, Dict, Int, Any, Constraint, String, Tuple, Iter, Literal, NoneType from runtype.typesystem import TypeSystem @@ -60,7 +60,7 @@ def test_phantom(self): - def test_pytypes(self): + def test_pytypes1(self): assert List + Dict == Dict + List assert Any + ((Any + Any) + Any) is Any @@ -133,7 +133,7 @@ def get_type(self, a): assert i.isinstance(3, 4) assert not i.isinstance(4, 3) - def test_pytypes(self): + def test_pytypes2(self): assert Tuple <= Tuple assert Tuple >= Tuple # assert Tuple[int] <= Tuple @@ -231,8 +231,12 @@ def test_canonize_pytypes(self): for b in bad: assert not t.test_instance(b), (t, b) - - + def test_any(self): + assert Any <= Any + assert Any <= Any + Int + assert Any <= Any + NoneType + assert Any + Int <= Any + assert Any + NoneType <= Any From d89cf87998f18c1ad56c78ec30be7fc9b234bb90 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sat, 30 Dec 2023 17:13:29 +0700 Subject: [PATCH 2/2] Fixes to typesystem, with tests --- runtype/base_types.py | 22 +++++++++++++++++++++- tests/test_basic.py | 14 +++++++++++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/runtype/base_types.py b/runtype/base_types.py index 2c00310..59ee9d2 100644 --- a/runtype/base_types.py +++ b/runtype/base_types.py @@ -50,7 +50,8 @@ def __ge__(self, other): def __le__(self, other): if isinstance(other, Type): - return other is self + if other is self: # Optimization + return True return NotImplemented @@ -72,6 +73,12 @@ def __le__(self, other): return super().__le__(other) + def __ge__(self, other): + # XXX hack + if isinstance(other, AnyType): + return False + return NotImplemented + class SumType(Type): """Implements a sum type, i.e. a disjoint union of a set of types. @@ -170,6 +177,11 @@ class ContainerType(DataType): def __getitem__(self, other): return GenericType(self, other) + def __le__(self, other): + # XXX hack + if isinstance(other, AnyType): + return True + return super().__le__(other) class GenericType(ContainerType): """Implements a generic type. i.e. a container for items of a specific type. @@ -214,6 +226,10 @@ def __le__(self, other): elif isinstance(other, DataType): return self.base <= other + elif isinstance(other, AnyType): + # HACK + return True + return NotImplemented def __ge__(self, other): @@ -223,6 +239,10 @@ def __ge__(self, other): elif isinstance(other, DataType): return self.base >= other + elif isinstance(other, AnyType): + # HACK + return False + return NotImplemented def __hash__(self): diff --git a/tests/test_basic.py b/tests/test_basic.py index 9b91229..9ba8b8b 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1035,9 +1035,17 @@ class Foo: "d": {"a": {"baz": 2}} } - - - + def test_any(self): + assert is_subtype(int, Any) + assert not is_subtype(Any, int) + assert is_subtype(Any, Any) + assert is_subtype(Any, Union[Any, int]) + assert is_subtype(Any, Union[Any, None]) + assert is_subtype(Union[Any, int], Any) + assert is_subtype(Union[Any, None], Any) + assert is_subtype(Union[Any, None], Union[Any, None]) + assert is_subtype(dict, Any, ) + assert not is_subtype(Any, dict) if __name__ == '__main__': unittest.main()