Skip to content

Commit

Permalink
Add validate hook
Browse files Browse the repository at this point in the history
  • Loading branch information
wodesuck committed Oct 27, 2024
1 parent ff9c75b commit 0014064
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
9 changes: 9 additions & 0 deletions jsonschema/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,12 @@ def __call__(
[referencing.jsonschema.Schema],
Iterable[tuple[str, Any]],
]

class ValidateHook(Protocol):
def __call__(
self,
is_valid: bool,
instance: Any,
schema: referencing.jsonschema.Schema,
) -> None:
...
4 changes: 3 additions & 1 deletion jsonschema/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# therefore, only import at type-checking time (to avoid circular references),
# but use `jsonschema` for any types which will otherwise not be resolvable
if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Mapping, Sequence

import referencing.jsonschema

Expand Down Expand Up @@ -102,6 +102,8 @@ class Validator(Protocol):
#: A function which given a schema returns its ID.
ID_OF: _typing.id_of

VALIDATE_HOOKS: ClassVar[Sequence]

#: The schema that will be used to validate instances
schema: Mapping | bool

Expand Down
46 changes: 44 additions & 2 deletions jsonschema/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def create(
applicable_validators: _typing.ApplicableValidators = methodcaller(
"items",
),
validate_hooks: Sequence[_typing.ValidateHook] = (),
):
"""
Create a new validator class.
Expand Down Expand Up @@ -207,6 +208,16 @@ def create(
implement similar behavior, you can typically ignore this argument
and leave it at its default.
validate_hooks:
A list of callables, will be called after validate.
Each callable should take 4 arguments:
1. is valid or not
2. the instance
3. the schema
Returns:
a new `jsonschema.protocols.Validator` class
Expand All @@ -220,6 +231,10 @@ def create(
default=referencing.Specification.OPAQUE,
)

def _call_validate_hooks(is_valid, instance, schema):
for hook in validate_hooks:
hook(is_valid, instance, schema)

@define
class Validator:

Expand All @@ -228,6 +243,7 @@ class Validator:
TYPE_CHECKER = type_checker
FORMAT_CHECKER = format_checker_arg
ID_OF = staticmethod(id_of)
VALIDATE_HOOKS = list(validate_hooks) # noqa: RUF012

_APPLICABLE_VALIDATORS = applicable_validators
_validators = field(init=False, repr=False, eq=False)
Expand Down Expand Up @@ -368,6 +384,7 @@ def iter_errors(self, instance, _schema=None):
_schema, validators = self.schema, self._validators

if _schema is True:
_call_validate_hooks(True, instance, _schema)
return
elif _schema is False:
yield exceptions.ValidationError(
Expand All @@ -377,8 +394,10 @@ def iter_errors(self, instance, _schema=None):
instance=instance,
schema=_schema,
)
_call_validate_hooks(False, instance, _schema)
return

is_valid = True
for validator, k, v in validators:
errors = validator(self, v, instance, _schema) or ()
for error in errors:
Expand All @@ -392,7 +411,9 @@ def iter_errors(self, instance, _schema=None):
)
if k not in {"if", "$ref"}:
error.schema_path.appendleft(k)
is_valid = False
yield error
_call_validate_hooks(is_valid, instance, _schema)

def descend(
self,
Expand All @@ -403,6 +424,7 @@ def descend(
resolver=None,
):
if schema is True:
_call_validate_hooks(True, instance, schema)
return
elif schema is False:
yield exceptions.ValidationError(
Expand All @@ -412,6 +434,7 @@ def descend(
instance=instance,
schema=schema,
)
_call_validate_hooks(False, instance, schema)
return

if self._ref_resolver is not None:
Expand All @@ -423,6 +446,7 @@ def descend(
)
evolved = self.evolve(schema=schema, _resolver=resolver)

is_valid = True
for k, v in applicable_validators(schema):
validator = evolved.VALIDATORS.get(k)
if validator is None:
Expand All @@ -444,10 +468,15 @@ def descend(
error.path.appendleft(path)
if schema_path is not None:
error.schema_path.appendleft(schema_path)
is_valid = False
yield error
_call_validate_hooks(is_valid, instance, schema)

def validate(self, *args, **kwargs):
for error in self.iter_errors(*args, **kwargs):
def validate(self, instance, _schema=None):
for error in self.iter_errors(instance, _schema):
if _schema is None:
_schema = self.schema
_call_validate_hooks(False, instance, _schema)
raise error

def is_type(self, instance, type):
Expand Down Expand Up @@ -498,6 +527,8 @@ def is_valid(self, instance, _schema=None):
self = self.evolve(schema=_schema)

error = next(self.iter_errors(instance), None)
if error is not None:
_call_validate_hooks(False, instance, self.schema)
return error is None

evolve_fields = [
Expand All @@ -520,6 +551,7 @@ def extend(
version=None,
type_checker=None,
format_checker=None,
validate_hooks=(),
):
"""
Create a new validator class by extending an existing one.
Expand Down Expand Up @@ -565,6 +597,12 @@ def extend(
If unprovided, the format checker of the extended
`jsonschema.protocols.Validator` will be carried along.
validate_hooks (collections.abc.Sequence):
a list of new validate hooks to extend with, whose
structure is as in `create`.
Returns:
a new `jsonschema.protocols.Validator` class extending the one
Expand All @@ -584,6 +622,9 @@ def extend(
all_validators = dict(validator.VALIDATORS)
all_validators.update(validators)

all_validate_hooks = list(validator.VALIDATE_HOOKS)
all_validate_hooks.extend(validate_hooks)

if type_checker is None:
type_checker = validator.TYPE_CHECKER
if format_checker is None:
Expand All @@ -596,6 +637,7 @@ def extend(
format_checker=format_checker,
id_of=validator.ID_OF,
applicable_validators=validator._APPLICABLE_VALIDATORS,
validate_hooks=all_validate_hooks,
)


Expand Down

0 comments on commit 0014064

Please sign in to comment.