Skip to content

Commit

Permalink
Merge pull request #49 from erezsh/dev
Browse files Browse the repository at this point in the history
Small fixes
  • Loading branch information
erezsh authored Mar 1, 2024
2 parents 5415e03 + edcb3d6 commit 186087a
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 4 deletions.
2 changes: 2 additions & 0 deletions runtype/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def create(cls, types):
if isinstance(t, SumType):
# Optimization: Flatten recursive SumTypes
x |= set(t.types)
elif isinstance(t, AnyType):
return t
else:
x.add(t)

Expand Down
4 changes: 2 additions & 2 deletions runtype/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def __call__(self, f):
tree = self.fname_to_tree[fname] = TypeTree(fname, self.typesystem, self.test_subtypes)

tree.define_function(f)
find_function_cached = tree.find_function_cached

@wraps(f)
def dispatched_f(*args, **kw):
f = tree.find_function_cached(args)
return f(*args, **kw)
return find_function_cached(args)(*args, **kw)

dispatched_f.__dispatcher__ = self
return dispatched_f
Expand Down
4 changes: 2 additions & 2 deletions runtype/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,11 @@ def _to_canon(self, t):
t = _forwardref_evaluate(t, self.frame.f_globals, self.frame.f_locals, frozenset())

if isinstance(t, tuple):
return SumType([to_canon(x) for x in t])
return SumType.create([to_canon(x) for x in t])

if hasattr(types, 'UnionType') and isinstance(t, types.UnionType):
res = [to_canon(x) for x in t.__args__]
return SumType(res)
return SumType.create(res)

origin = getattr(t, '__origin__', None)
if hasattr(typing, '_AnnotatedAlias') and isinstance(t, typing._AnnotatedAlias):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,21 @@ def f(a: int):
f.__module__ = 'a'
self.assertRaises(ValueError, multidispatch, f)

def test_none(self):
dp = Dispatch()

@dp
def f(t: None):
return "none"

@dp
def f(t: int):
return "int"

assert f(2) == "int"
assert f(None) == "none"


class TestDataclass(TestCase):
def setUp(self):
pass
Expand Down
4 changes: 4 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ def test_pytypes2(self):
self.assertRaises(ValueError, type_caster.to_canon, typing.Tuple[int, str, ...])


def test_pytypes3(self):
assert Any + Int == Any
assert Int + Any == Any


def test_canonize_pytypes(self):
pytypes = [
Expand Down

0 comments on commit 186087a

Please sign in to comment.