Skip to content

Commit

Permalink
Added support for dispatch on Type[...] when enable_generics=True
Browse files Browse the repository at this point in the history
  • Loading branch information
erezsh committed Jul 21, 2024
1 parent 373aef0 commit 43a0f26
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
33 changes: 18 additions & 15 deletions runtype/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,24 @@ class DispatchError(Exception):


# TODO: Remove test_subtypes, replace with support for Type[], like isa(t, Type[t])
@dataclass
class MultiDispatch:
"""Creates a dispatch group for multiple dispatch
Parameters:
typesystem - instance for interfacing with the typesystem
test_subtypes: indices of params that should be matched by subclass instead of isinstance.
"""
typesystem: TypeSystem
test_subtypes: Sequence[int] = ()
enable_generics: bool = False

def __init__(self, typesystem: TypeSystem, test_subtypes: Sequence[int] = ()):
def __post_init__(self):
self.fname_to_tree: Dict[str, TypeTree] = {}
self.typesystem: TypeSystem = typesystem
if test_subtypes:
if self.test_subtypes:
warnings.warn("The test_subtypes option is deprecated and will be removed in the future."
"Use typing.Type[t] instead.", DeprecationWarning)

self.test_subtypes = test_subtypes

def __call__(self, func=None, *, priority=None):
"""Decorate the function
Expand All @@ -54,7 +56,7 @@ def __call__(self, func=None, *, priority=None):
tree = self.fname_to_tree[fname]
except KeyError:
tree = self.fname_to_tree[fname] = TypeTree(
fname, self.typesystem, self.test_subtypes
fname, self.typesystem, self.test_subtypes, enable_generics=self.enable_generics
)

tree.define_function(func)
Expand Down Expand Up @@ -106,14 +108,16 @@ class TypeTree:
_cache: Dict[tuple, Callable]
typesystem: TypeSystem
test_subtypes: Sequence[int]
enable_generics: bool

def __init__(self, name: str, typesystem: TypeSystem, test_subtypes: Sequence[int]):
def __init__(self, name: str, typesystem: TypeSystem, test_subtypes: Sequence[int], enable_generics=False):
self.root = TypeNode()
self._cache = {}
self.name = name
self.typesystem = typesystem
self.test_subtypes = test_subtypes
self._get_type = self.typesystem.get_type
self.enable_generics = enable_generics

if self.test_subtypes:
# Deprecated!!
Expand All @@ -132,13 +136,7 @@ def get_arg_types(self, args):
def find_function(self, args):
nodes = [self.root]
for i, a in enumerate(args):
nodes = [
n
for node in nodes
for n in node.follow_arg(
a, self.typesystem, test_subtype=i in self.test_subtypes
)
]
nodes = [ n for node in nodes for n in node.follow_arg( a, self.typesystem, test_subtype=i in self.test_subtypes) ]

funcs = [node.func for node in nodes if node.func]

Expand All @@ -164,10 +162,15 @@ def _old_find_function_cached(self, args):

def find_function_cached(self, args):
"Memoized version of find_function"
if self.enable_generics:
sig = tuple(a if isinstance(a, type) else self._get_type(a)
for a in args)
else:
sig = tuple(map(self._get_type, args))

try:
return self._cache[tuple(map(self._get_type, args))]
return self._cache[sig]
except KeyError:
sig = tuple(map(self._get_type, args))
f = self.find_function(args)
self._cache[sig] = f
return f
Expand Down
11 changes: 6 additions & 5 deletions runtype/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextvars
from contextlib import contextmanager


def get_func_signatures(typesystem, f):
sig = inspect.signature(f)
typesigs = []
Expand All @@ -13,9 +14,9 @@ def get_func_signatures(typesystem, f):
t = p.annotation
if t is sig.empty:
t = typesystem.default_type
else:
# Canonize to detect more collisions on construction, instead of during dispatch
t = typesystem.to_canonical_type(t)

# Canonicalize to detect more collisions on construction, instead of during dispatch
t = typesystem.to_canonical_type(t)

if p.default is not p.empty:
# From now on, everything is optional
Expand All @@ -28,7 +29,7 @@ def get_func_signatures(typesystem, f):


class ContextVar:
def __init__(self, default, name=''):
def __init__(self, default, name=""):
self._var = contextvars.ContextVar(name, default=default)

def get(self):
Expand All @@ -40,4 +41,4 @@ def __call__(self, value):
try:
yield
finally:
self._var.reset(token)
self._var.reset(token)

0 comments on commit 43a0f26

Please sign in to comment.