Skip to content

Commit

Permalink
Merge pull request #67 from erezsh/dev4
Browse files Browse the repository at this point in the history
Fix: Throw error when attempting to dispatch on literal
  • Loading branch information
erezsh authored Oct 15, 2024
2 parents 4c177b0 + 625d149 commit 1f7ee93
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
4 changes: 4 additions & 0 deletions runtype/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def define_function(self, f):
for signature in get_func_signatures(self.typesystem, f):
node = self.root
for t in signature:
if not isinstance(t, type):
# XXX this is a temporary fix for preventing certain types from being used for dispatch
if not getattr(t, 'ALLOW_DISPATCH', True):
raise ValueError(f"Type {t} cannot be used for dispatch")
node = node.follow_type[t]

if node.func is not None:
Expand Down
5 changes: 5 additions & 0 deletions runtype/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def test_instance(self, obj, sampler=None):
class OneOf(PythonType):
values: typing.Sequence

ALLOW_DISPATCH = False

def __init__(self, values):
self.values = values

Expand All @@ -218,6 +220,7 @@ def cast_from(self, obj):
raise TypeMismatchError(obj, self)



class GenericType(base_types.GenericType, PythonType):
base: PythonDataType
item: PythonType
Expand Down Expand Up @@ -448,6 +451,8 @@ def cast_from(self, obj):


class _NoneType(OneOf):
ALLOW_DISPATCH = True # Make an exception

def __init__(self):
super().__init__([None])

Expand Down
18 changes: 18 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,24 @@ def f(t: Tree[int]):

f(Tree())

def test_literal_dispatch(self):
try:
@multidispatch
def f(x: typing.Literal[1]):
return 1

@multidispatch
def f(x: typing.Literal[2]):
return 2
except ValueError:
pass
else:
assert False

# If it was working..
# assert f(1) == 1
# assert f(2) == 2


class TestDataclass(TestCase):
def setUp(self):
Expand Down

0 comments on commit 1f7ee93

Please sign in to comment.