From 51dc223cbcfb35b77c365a8618099305d22d8519 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 15 Aug 2023 22:19:57 +0300 Subject: [PATCH] Fix for tuple[x, ...]; Fix for comparing generic types; support for types.UnionType --- README.md | 2 +- runtype/base_types.py | 4 ++-- runtype/pytypes.py | 15 +++++++++++++-- tests/test_basic.py | 14 ++++++++++---- tests/test_types.py | 22 +++++++++++++++++++--- 5 files changed, 45 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index f17f3a4..8cb2339 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ It is: - :star: [**type utilities**](https://runtype.readthedocs.io/en/latest/types.html) - Provides a set of classes to implement your own type-system. - Used by runtype itself, to emulate the Python type-system. - + ## Docs diff --git a/runtype/base_types.py b/runtype/base_types.py index 8228599..7cc3d10 100644 --- a/runtype/base_types.py +++ b/runtype/base_types.py @@ -202,7 +202,7 @@ def __eq__(self, other): def __le__(self, other): - if isinstance(other, GenericType): + if isinstance(other, type(self)): return self.base <= other.base and self.item <= other.item elif isinstance(other, DataType): @@ -211,7 +211,7 @@ def __le__(self, other): return NotImplemented def __ge__(self, other): - if isinstance(other, GenericType): + if isinstance(other, type(self)): return self.base >= other.base and self.item >= other.item elif isinstance(other, DataType): diff --git a/runtype/pytypes.py b/runtype/pytypes.py index b1b1085..a538643 100644 --- a/runtype/pytypes.py +++ b/runtype/pytypes.py @@ -2,6 +2,7 @@ Python Types - contains an implementation of a Runtype type system that is parallel to the Python type system. """ +import types from abc import abstractmethod, ABC from contextlib import suppress import collections @@ -166,7 +167,9 @@ def __le__(self, other): return isinstance(other, TupleType) def __ge__(self, other): - if isinstance(other, TupleType): + if isinstance(other, TupleEllipsisType): + return True + elif isinstance(other, TupleType): return True elif isinstance(other, DataType): return False @@ -286,6 +289,10 @@ def cast_from(self, obj): kt, vt = self.item.types return {kt.cast_from(k): vt.cast_from(v) for k, v in obj.items()} +class TupleEllipsisType(SequenceType): + def __repr__(self): + return '%s[%s, ...]' % (self.base, self.item) + Object = PythonDataType(object) Iter = SequenceType(PythonDataType(collections.abc.Iterable)) @@ -296,7 +303,7 @@ def cast_from(self, obj): Dict = DictType(PythonDataType(dict)) Mapping = DictType(PythonDataType(abc.Mapping)) Tuple = TupleType() -TupleEllipsis = SequenceType(PythonDataType(tuple)) +TupleEllipsis = TupleEllipsisType(PythonDataType(tuple)) # Float = PythonDataType(float) Bytes = PythonDataType(bytes) Callable = PythonDataType(abc.Callable) # TODO: Generic @@ -470,6 +477,10 @@ def _to_canon(self, t): # Python 3.6 return to_canon(t.__args__[0]) + if isinstance(t, types.UnionType): + res = [to_canon(x) for x in t.__args__] + return SumType(res) + try: t.__origin__ except AttributeError: diff --git a/tests/test_basic.py b/tests/test_basic.py index 9300ea6..68dda6f 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -94,16 +94,22 @@ def test_basic(self): assert not isa(frozenset({'a'}), Set[int]) - def test_basic2(self): + def test_issubclass(self): assert issubclass(List[Tuple], list) if hasattr(typing, 'Annotated'): a = typing.Annotated[int, range(1, 10)] - assert is_subtype(a, int) - assert is_subtype(int, a) + assert issubclass(a, int) + assert issubclass(int, a) assert isa(1, a) - def test_issubclass(self): + assert issubclass(typing.Tuple, tuple) + assert issubclass(typing.Tuple[int], tuple) + assert issubclass(typing.Tuple[int, ...], tuple) + assert issubclass(typing.Tuple[int], typing.Tuple[typing.Union[int, str]]) + assert issubclass(typing.Tuple[int, ...], typing.Tuple[typing.Union[int, str], ...]) + + def test_issubclass_tuple(self): # test class tuple t = int, float assert issubclass(int, t) diff --git a/tests/test_types.py b/tests/test_types.py index 0f0424c..e23f5ba 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -162,9 +162,25 @@ def test_pytypes(self): assert type_caster.to_canon(typing.Dict[int, int]).cast_from({}) == {} assert type_caster.to_canon(typing.Dict[int, int]).cast_from([]) == {} - t = type_caster.to_canon(typing.Tuple[int, ...]) - assert t.test_instance((1,2,3)) - assert not t.test_instance((1,2,3, 'a')) + tpl0 = type_caster.to_canon(typing.Tuple) + tpl1 = type_caster.to_canon(typing.Tuple[int]) + tpl2 = type_caster.to_canon(typing.Tuple[int, ...]) + tpl0b = type_caster.to_canon(tuple) + tpl3 = type_caster.to_canon(typing.Tuple[typing.Union[int, str]]) + tpl4 = type_caster.to_canon(typing.Tuple[typing.Union[int, str], ...]) + assert tpl0 is tpl0b + assert tpl1 <= tpl0 + assert tpl2 <= tpl0 + + assert tpl3 <= tpl0 + assert tpl1 <= tpl3 + assert not tpl3 <= tpl1 + assert not tpl0 <= tpl3 + + assert tpl2 <= tpl4 + + assert tpl2.test_instance((1,2,3)) + assert not tpl2.test_instance((1,2,3, 'a')) if sys.version_info >= (3, 11): self.assertRaises(ValueError, type_caster.to_canon, typing.Tuple[...]) self.assertRaises(ValueError, type_caster.to_canon, typing.Tuple[int, str, ...])