Is it possible for method annotation of an eqx.Module to pick up instance field annotations? #241
What do you mean by option 2 being 'inconsistent'? Agreed it's a little verbose but should accomplish what you're after.
I mean that in this case, the field annotations (e.g., `in`/`out`) and the method annotations (e.g., `self.in_size`/`self.out_size`) refer to the same concept but use very different naming styles.

```python
class Linear(eqx.Module):
    weight: Float[Array, "out in"]
    bias: Float[Array, "out"]
    # ...

    @jaxtyped(typechecker=beartype)
    def __call__(self, x: Float[Array, "{self.in_size}"]) -> Float[Array, "{self.out_size}"]:
        return self.weight @ x + self.bias
```

Of course, I can change

Additionally, it can sometimes be challenging to directly associate the dimensions with hyperparameters, which may lead to something like this:

```python
def __call__(self, x: Float[Array, "{self.weight.shape[1]}"]) -> Float[Array, "{self.weight.shape[0]}"]:
    return self.weight @ x + self.bias
```

I appreciate the design of jaxtyping because it strikes a nice balance between static type checking, runtime type checking, and documentation readability. If annotations have to be written like
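To make the gap concrete, here is a toy model of named-dimension binding in plain Python. It is not jaxtyping's implementation: "arrays" are just shape tuples, and `check_shape` is a hypothetical helper. It assumes, as jaxtyping does within a single check, that a dimension name must resolve to the same size everywhere it appears in one memo. Checking the fields and the method input in separate memos lets a wrongly sized input through; seeding the method's memo from the fields catches it.

```python
def check_shape(memo: dict, shape: tuple, spec: str) -> None:
    """Bind each name in `spec` (e.g. "out in") to the matching entry of `shape`.

    A name already bound in `memo` must have the same size; otherwise raise.
    """
    names = spec.split()
    if len(names) != len(shape):
        raise ValueError(f"expected {len(names)} dims for {spec!r}, got shape {shape}")
    for name, size in zip(names, shape):
        expected = memo.setdefault(name, size)
        if expected != size:
            raise ValueError(f"dim {name!r}: expected {expected}, got {size}")


weight_shape, bias_shape = (3, 2), (3,)

# Construction-time check: the fields share one memo, so "out" must agree.
init_memo: dict = {}
check_shape(init_memo, weight_shape, "out in")
check_shape(init_memo, bias_shape, "out")

# Call-time check with a fresh memo: "in" is unbound, so a wrongly sized input passes.
call_memo: dict = {}
check_shape(call_memo, (5,), "in")

# Seeding the call-time memo from the fields catches the mismatch.
contextual_memo = dict(init_memo)
try:
    check_shape(contextual_memo, (5,), "in")
except ValueError as err:
    print(err)  # dim 'in': expected 2, got 5
```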
Ah, I see what you're getting at. Indeed, something like Unfortunately I don't have a better solution to this problem; if you have a suggestion then I'm happy to hear it. FWIW this doesn't worry me too much. There are many things we could, in principle, put into the type system -- e.g. is this integer positive? -- that we typically don't. Type systems are always about picking a trade-off between verbosity and validation, so I think we're in the usual situation in this regard. :)
@patrick-kidger I went through the jaxtyping internals and found that implementing the feature I requested wasn't too difficult. I've created a prototype and would appreciate your feedback on it. While it generally works, the design, interface, naming, internals, and edge-case handling still need refinement. Regarding the semantics: since a class’s suite (including attributes and parameter annotations in method declarations) shares the same dedicated local namespace, I believe it’s intuitive to extend these semantics to the

```diff
diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py
index 2d0eac7..1883b6c 100644
--- a/jaxtyping/_decorator.py
+++ b/jaxtyping/_decorator.py
@@ -30,7 +30,7 @@ from jaxtyping import AbstractArray
 from ._config import config
 from ._errors import AnnotationError, TypeCheckError
-from ._storage import pop_shape_memo, push_shape_memo, shape_str
+from ._storage import get_shape_memo, pop_shape_memo, push_shape_memo, shape_str
 
 
 class _Sentinel:
@@ -52,7 +52,7 @@ def jaxtyped(fn, *, typechecker=_sentinel):
     ...
 
 
-def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
+def jaxtyped(fn=_sentinel, *, typechecker=_sentinel, contextual=False, _keep_shape_memo=False):
     """Decorate a function with this to perform runtime type-checking of its arguments
     and return value. Decorate a dataclass to perform type-checking of its attributes.
@@ -252,7 +252,7 @@ def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
         typechecker = None
 
     if fn is _sentinel:
-        return ft.partial(jaxtyped, typechecker=typechecker)
+        return ft.partial(jaxtyped, typechecker=typechecker, contextual=contextual)
     elif inspect.isclass(fn):
         if dataclasses.is_dataclass(fn) and typechecker is not None:
             # This does not check that the arguments passed to `__init__` match the
@@ -260,6 +260,9 @@ def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
             # dataclass-generated `__init__` used alongside
             # `equinox.field(converter=...)`
 
+            assert not contextual
+            fn.__jaxtyped__ = True
+
             init = fn.__init__
 
             @ft.wraps(init)
@@ -280,7 +283,7 @@ def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
                 # metaclass `__call__`, because Python doesn't allow you
                 # monkey-patch metaclasses.
                 if self.__class__.__init__ is fn.__init__:
-                    _check_dataclass_annotations(self, typechecker)
+                    _check_dataclass_annotations(self, typechecker, keep_shape_memo=False)
 
             fn.__init__ = __init__
             return fn
@@ -515,13 +518,30 @@ def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
             bound = param_signature.bind(*args, **kwargs)
             bound.apply_defaults()
 
-            memos = push_shape_memo(bound.arguments)
+            if contextual:
+                assert len(args) > 0
+                self = args[0]
+                assert getattr(self, "__jaxtyped__", False)
+                try:
+                    _check_dataclass_annotations(self, typechecker, keep_shape_memo=True)
+                except:
+                    # 1. dataclass not initialized through init
+                    # 2. field modified after initialization
+                    pop_shape_memo()
+                    raise
+                memos = get_shape_memo()
+                *_, arguments = memos
+                arguments.clear()  # Comment to merge
+                arguments.update(bound.arguments)
+            else:
+                memos = push_shape_memo(bound.arguments)
             try:
                 # Put this in a separate frame to make debugging easier, without
                 # just always ending up on the `pop_shape_memo` line below.
                 return wrapped_fn_impl(args, kwargs, bound, memos)
             finally:
-                pop_shape_memo()
+                if not _keep_shape_memo:
+                    pop_shape_memo()
 
         return wrapped_fn
@@ -534,7 +554,7 @@ class _JaxtypingContext:
         pop_shape_memo()
 
 
-def _check_dataclass_annotations(self, typechecker):
+def _check_dataclass_annotations(self, typechecker, keep_shape_memo):
     """Creates and calls a function that checks the attributes of `self`
 
     `self` should be a dataclass instance. `typechecker` should be e.g.
@@ -578,7 +598,7 @@ def _check_dataclass_annotations(self, typechecker):
         signature,
         output=False,
     )
-    f = jaxtyped(f, typechecker=typechecker)
+    f = jaxtyped(f, typechecker=typechecker, contextual=False, _keep_shape_memo=keep_shape_memo)
     f(self, **values)
```

This patch allows the following snippet to work correctly:

```python
import equinox as eqx
import jax
from jaxtyping import Array, Float, Key, jaxtyped, config

config.update("jaxtyping_remove_typechecker_stack", True)
from beartype import beartype


@jaxtyped(typechecker=beartype)
class Linear(eqx.Module):
    in_size: int = eqx.field(static=True)
    out_size: int = eqx.field(static=True)
    weight: Float[Array, "O I"]
    bias: Float[Array, "O"]

    def __init__(self, in_size: int, out_size: int, *, key: Key[Array, ""]):
        self.in_size = in_size
        self.out_size = out_size
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    @jaxtyped(typechecker=beartype, contextual=True)
    def __call__(self, x: Float[Array, "I"]) -> Float[Array, "O"]:
        return self.weight @ x + self.bias


model = Linear(2, 3, key=jax.random.key(0))
x = jax.numpy.zeros((2,))
model(x)
```
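The core idea of the patch (check the dataclass fields first, then reuse the same shape memo for the method's arguments and return value) can be modelled in plain Python. This is a toy sketch, not jaxtyping's implementation: "arrays" are shape tuples, dimension strings are stored in `typing.Annotated`, and `check_shape` and `contextual` are hypothetical helpers.

```python
import dataclasses
import functools
import inspect
from typing import Annotated, get_args, get_origin, get_type_hints

Shape = tuple[int, ...]  # toy stand-in for an array: just its shape


def check_shape(memo, shape, spec):
    # Bind each dimension name to a size; a repeated name must match.
    names = spec.split()
    if len(names) != len(shape):
        raise ValueError(f"expected {len(names)} dims for {spec!r}, got shape {shape}")
    for name, size in zip(names, shape):
        expected = memo.setdefault(name, size)
        if expected != size:
            raise ValueError(f"dim {name!r}: expected {expected}, got {size}")


def _spec(hint):
    # Extract the dimension string from Annotated[Shape, "O I"]; None if absent.
    if get_origin(hint) is Annotated:
        return get_args(hint)[1]
    return None


def contextual(method):
    """Hypothetical decorator: check fields and arguments in one shared memo."""
    sig = inspect.signature(method)
    hints = get_type_hints(method, include_extras=True)

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        memo = {}
        # 1. Bind dimension names from the instance's annotated fields.
        field_hints = get_type_hints(type(self), include_extras=True)
        for field in dataclasses.fields(self):
            spec = _spec(field_hints.get(field.name))
            if spec is not None:
                check_shape(memo, getattr(self, field.name), spec)
        # 2. Check the arguments against the same memo.
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            spec = _spec(hints.get(name))
            if spec is not None:
                check_shape(memo, value, spec)
        # 3. Check the return value against the same memo.
        result = method(self, *args, **kwargs)
        spec = _spec(hints.get("return"))
        if spec is not None:
            check_shape(memo, result, spec)
        return result

    return wrapper


@dataclasses.dataclass
class Linear:
    weight: Annotated[Shape, "O I"]
    bias: Annotated[Shape, "O"]

    @contextual
    def __call__(self, x: Annotated[Shape, "I"]) -> Annotated[Shape, "O"]:
        out_size, _ = self.weight
        return (out_size,)


model = Linear(weight=(3, 2), bias=(3,))
print(model((2,)))  # (3,)
```

Because the memo starts from the fields, `model((5,))` raises `ValueError` here. Unlike the prototype, this toy builds a fresh memo per call instead of managing jaxtyping's push/pop memo stack, so it says nothing about nested `jaxtyped` calls.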
Okay, I quite like this approach! This seems like a nice improvement. I think my main goal would be to remove the two extra arguments to
I'm aware that the former is technically a (very minor) breaking change that could result in extra errors. We can bump the version number to reflect this; I'd rather do that than have an extra flag.
Detecting this at the time of class definition seems non-trivial, but it becomes much easier when the dataclass method is called, by checking via
I was aware of this approach, but when I was developing the feature, I was uncertain about how to ensure the context correctness of both functions, especially concerning the In addition, several design choices need to be made:
Agreed that checking at decoration time sounds hard. Checking at runtime seems doable, however:

```python
import functools as ft
import inspect
import types


def _is_method(fn, args, kwargs):
    if isinstance(fn, types.FunctionType):
        if len(args) > 0:
            first_arg = args[0]
            cls_dict = getattr(type(first_arg), "__dict__", {})
            is_method = fn in cls_dict.values()
        else:
            parameters = inspect.signature(fn).parameters
            if len(parameters) > 0:
                first_arg_name = next(iter(parameters))
                try:
                    first_arg = kwargs[first_arg_name]
                except KeyError:
                    # Presumably we're about to get a type error from a missing
                    # argument. Don't do anything here, let the normal function call
                    # raise the error.
                    return
                else:
                    cls_dict = getattr(type(first_arg), "__dict__", {})
                    is_method = fn in cls_dict.values()
            else:
                if len(kwargs) == 0:
                    is_method = False
                else:
                    # Likewise, presumably a type error.
                    return
    else:
        # Included this branch just in case we have `jaxtyped` wrapping other
        # objects, which may not be hashable -- which is required for the
        # `fn in cls_dict` check above.
        is_method = False
    print(f"{fn} is method: {is_method}")


def is_method(fn):
    @ft.wraps(fn)
    def wrapper(*args, **kwargs):
        _is_method(wrapper, args, kwargs)
        return fn(*args, **kwargs)

    return wrapper


@is_method
def global_function():
    pass


class A:
    @is_method
    def method_function(self):
        pass


def wrapper():
    @is_method
    def closure_function():
        pass

    return closure_function


global_function()
A().method_function()
wrapper()()
```

No strong feelings on only applying to jaxtyped dataclasses vs all dataclasses. I think it's rare for there to be unchecked dataclasses that nonetheless use jaxtyping annotations, after all. I would like to avoid setting Not including
Currently, unless I am missing something, for an `eqx.Module`, the instance field annotations and instance method annotations are disjoint. This prevents the module input from being sufficiently checked (see 1 in snippet). In other words, field annotations only catch bugs during module initialization but are unhelpful when the module is invoked by the caller. While I could come up with several workarounds (see 2 and 3 in snippet), they come at the cost of increased complexity. I went through the issues and did not find relevant discussions. What is the suggested way to address this? Can this be implemented in jaxtyping?