Skip to content

Commit

Permalink
Merge pull request #50 from erezsh/dev
Browse files Browse the repository at this point in the history
Big optimizations
  • Loading branch information
erezsh authored Mar 5, 2024
2 parents 72018f6 + 620f418 commit d274e99
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 35 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "runtype"
version = "0.4.2"
version = "0.5.0b"
description = "Type dispatch and validation for run-time Python"
authors = ["Erez Shinan <[email protected]>"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion runtype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
assert_isa, isa, issubclass, validate_func, is_subtype)
from .pytypes import Constraint, String, Int, cv_type_checking

__version__ = "0.4.2"
__version__ = "0.5.0b"
__all__ = (
'dataclass',
'DispatchError', 'MultiDispatch',
Expand Down
24 changes: 18 additions & 6 deletions runtype/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,26 +198,28 @@ def __init__(self, base, item=Any):
class Validator(ABC):
"""Defines the validator interface."""

@abstractmethod
def validate_instance(self, obj, sampler: Optional[SamplerType] = None):
"""Validates obj, raising a TypeMismatchError if it does not conform.
If sampler is provided, it will be applied to the instance in order to
validate only a sample of the object. This approach may validate much faster,
but might miss anomalies in the data.
"""
if not self.test_instance(obj, sampler):
raise TypeMismatchError(obj, self)

@abstractmethod
def test_instance(self, obj, sampler=None):
"""Tests obj, returning a True/False for whether it conforms or not.
If sampler is provided, it will be applied to the instance in order to
validate only a sample of the object.
"""
try:
self.validate_instance(obj, sampler)
return True
except TypeMismatchError:
return False
# try:
# self.validate_instance(obj, sampler)
# return True
# except TypeMismatchError:
# return False


class Constraint(Validator, Type):
Expand All @@ -235,6 +237,16 @@ def validate_instance(self, inst, sampler=None):
if not p(inst):
raise TypeMismatchError(inst, self)

def test_instance(self, inst, sampler=None):
"""Makes sure the instance conforms by applying it to all the predicates."""
if not self.type.test_instance(inst, sampler):
return False

for p in self.predicates:
if not p(inst):
return False
return True


# fmt: off
@dp
Expand Down
27 changes: 23 additions & 4 deletions runtype/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
from typing import Any, Dict, Callable, Sequence
from operator import itemgetter
import warnings

from dataclasses import dataclass

Expand All @@ -25,6 +26,10 @@ class MultiDispatch:
def __init__(self, typesystem: TypeSystem, test_subtypes: Sequence[int] = ()):
self.fname_to_tree: Dict[str, TypeTree] = {}
self.typesystem: TypeSystem = typesystem
if test_subtypes:
warnings.warn("The test_subtypes option is deprecated and will be removed in the future."
"Use typing.Type[t] instead.", DeprecationWarning)

self.test_subtypes = test_subtypes

def __call__(self, func=None, *, priority=None):
Expand Down Expand Up @@ -106,17 +111,21 @@ def __init__(self, name: str, typesystem: TypeSystem, test_subtypes: Sequence[in
self.name = name
self.typesystem = typesystem
self.test_subtypes = test_subtypes
self._get_type = self.typesystem.get_type

if self.test_subtypes:
# Deprecated!!
self.find_function_cached = self._old_find_function_cached

def get_arg_types(self, args):
get_type = self.typesystem.get_type
if self.test_subtypes:
# TODO can be made more efficient
return tuple(
(a if i in self.test_subtypes else get_type(a))
(a if i in self.test_subtypes else self._get_type(a))
for i, a in enumerate(args)
)

return tuple(map(get_type, args))
return tuple(map(self._get_type, args))

def find_function(self, args):
nodes = [self.root]
Expand All @@ -141,7 +150,7 @@ def find_function(self, args):
((f, _sig),) = funcs
return f

def find_function_cached(self, args):
def _old_find_function_cached(self, args):
"Memoized version of find_function"
sig = self.get_arg_types(args)
try:
Expand All @@ -151,6 +160,16 @@ def find_function_cached(self, args):
self._cache[sig] = f
return f

def find_function_cached(self, args):
"Memoized version of find_function"
try:
return self._cache[tuple(map(self._get_type, args))]
except KeyError:
sig = tuple(map(self._get_type, args))
f = self.find_function(args)
self._cache[sig] = f
return f

def define_function(self, f):
for signature in get_func_signatures(self.typesystem, f):
node = self.root
Expand Down
90 changes: 72 additions & 18 deletions runtype/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def cast_from(self, obj):


class AnyType(base_types.AnyType, PythonType):
def validate_instance(self, obj, sampler=None):
def test_instance(self, obj, sampler=None):
return True

def cast_from(self, obj):
Expand All @@ -89,6 +89,16 @@ def validate_instance(self, obj, sampler=None):
raise LengthMismatchError(self, obj)
for type_, item in zip(self.types, obj):
type_.validate_instance(item, sampler)

def test_instance(self, obj, sampler=None):
if not isinstance(obj, tuple):
return False
if self.types and len(obj) != len(self.types):
return False
for type_, item in zip(self.types, obj):
if not type_.test_instance(item, sampler):
return False
return True


class SumType(base_types.SumType, PythonType):
Expand All @@ -102,12 +112,33 @@ def __init__(self, types: typing.Sequence[PythonType]):
types = rest + [OneOf([v for t in one_ofs for v in t.values])]
super().__init__(types)

# Optimization for instance validation
data_types = []
self.other_types = []
for t in types:
if isinstance(t, PythonDataType):
data_types.append(t.kernel)
else:
self.other_types.append(t)
self.data_types = tuple(data_types)


def validate_instance(self, obj, sampler=None):
for t in self.types:
if isinstance(obj, self.data_types):
return
for t in self.other_types:
if t.test_instance(obj):
return
raise TypeMismatchError(obj, self)

def test_instance(self, obj, sampler=None):
if isinstance(obj, self.data_types):
return True
for t in self.other_types:
if t.test_instance(obj):
return True
return False

def cast_from(self, obj):
for t in self.types:
with suppress(TypeError):
Expand All @@ -129,9 +160,8 @@ class PythonDataType(DataType, PythonType):
def __init__(self, kernel, supertypes={Any}):
self.kernel = kernel

def validate_instance(self, obj, sampler=None):
if not isinstance(obj, self.kernel):
raise TypeMismatchError(obj, self)
def test_instance(self, obj, sampler=None):
return isinstance(obj, self.kernel)

def __repr__(self):
try:
Expand All @@ -144,9 +174,7 @@ def cast_from(self, obj):
# kernel is probably a class. Cast the dict into the class.
return self.kernel(**obj)

try:
self.validate_instance(obj)
except TypeMismatchError:
if not self.test_instance(obj):
cast = getattr(self.kernel, 'cast_from', None)
if cast:
return cast(obj)
Expand All @@ -155,9 +183,8 @@ def cast_from(self, obj):
return obj

class TupleType(PythonType):
def validate_instance(self, obj, sampler=None):
if not isinstance(obj, tuple):
raise TypeMismatchError(obj, self)
def test_instance(self, obj, sampler=None):
return isinstance(obj, tuple)


# cv_type_checking allows the user to define different behaviors for their objects
Expand All @@ -172,11 +199,10 @@ class OneOf(PythonType):
def __init__(self, values):
self.values = values

def validate_instance(self, obj, sampler=None):
def test_instance(self, obj, sampler=None):
tok = cv_type_checking.set(True)
try:
if obj not in self.values:
raise TypeMismatchError(obj, self)
return obj in self.values
finally:
cv_type_checking.reset(tok)

Expand All @@ -189,12 +215,26 @@ def cast_from(self, obj):


class GenericType(base_types.GenericType, PythonType):
base: PythonType
item: PythonType

def __init__(self, base: PythonType, item=Any):
return super().__init__(base, item)


class SequenceType(GenericType):

def test_instance(self, obj, sampler=None):
if not self.base.test_instance(obj):
return False
if self.item is not Any:
if sampler:
obj = sampler(obj)
for item in obj:
if not self.item.test_instance(item, sampler):
return False
return True

def validate_instance(self, obj, sampler=None):
self.base.validate_instance(obj)
if self.item is not Any:
Expand All @@ -217,6 +257,7 @@ def cast_from(self, obj):


class DictType(GenericType):
item: ProductType

def __init__(self, base: PythonType, item=Any*Any):
super().__init__(base)
Expand All @@ -236,6 +277,21 @@ def validate_instance(self, obj, sampler=None):
kt.validate_instance(k, sampler)
vt.validate_instance(v, sampler)

def test_instance(self, obj, sampler=None):
if not self.base.test_instance(obj):
return False
if self.item is not Any:
kt, vt = self.item.types
items = obj.items()
if sampler:
items = sampler(items)
for k, v in items:
if not kt.test_instance(k, sampler):
return False
if not vt.test_instance(v, sampler):
return False
return True

def __getitem__(self, item):
assert self.item == Any*Any
return type(self)(self.base, item)
Expand All @@ -261,11 +317,9 @@ def __repr__(self):
return '%s[%s, ...]' % (self.base, self.item)

class TypeType(GenericType):

def validate_instance(self, obj, sampler=None):
def test_instance(self, obj, sampler=None):
t = type_caster.to_canon(obj)
if not t <= self.item:
raise TypeMismatchError(obj, self)
return t <= self.item


Object = PythonDataType(object)
Expand Down
12 changes: 7 additions & 5 deletions runtype/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ def isa(obj, t):
Behaves like Python's isinstance, but supports the ``typing`` module and constraints.
"""
try:
ensure_isa(obj, t)
return True
except TypeMismatchError:
return False
ct = type_caster.to_canon(t)
return ct.test_instance(obj)
# try:
# ensure_isa(obj, t)
# return True
# except TypeMismatchError:
# return False


def assert_isa(obj, t):
Expand Down

0 comments on commit d274e99

Please sign in to comment.