Creating instances of jaxtyped dataclasses is slow #232

Open
padix-key opened this issue Jul 17, 2024 · 5 comments

@padix-key
Contributor

Annotating a dataclass with @jaxtyped makes creating instances of that class ~1000x slower.
This is especially problematic in cases where the entire package is jaxtyped with install_import_hook(), because it is not possible to exclude a frequently used dataclass from being jaxtyped.

Here is a small benchmark:

from dataclasses import dataclass
import time
from jaxtyping import jaxtyped
from beartype import beartype


N = 1000


class VanillaClass:

    def __init__(self, foo: str):
        self.foo = foo

@dataclass
class VanillaDataclass:

    foo: str

@jaxtyped(typechecker=beartype)
class JaxtypedClass:

    def __init__(self, foo: str):
        self.foo = foo

@jaxtyped(typechecker=beartype)
@dataclass
class JaxtypedDataclass:

    foo: str

@beartype
class BeartypeClass:

    def __init__(self, foo: str):
        self.foo = foo

@beartype
@dataclass
class BeartypeDataclass:

    foo: str


for c in [
    VanillaClass,
    VanillaDataclass,
    JaxtypedClass,
    JaxtypedDataclass,
    BeartypeClass,
    BeartypeDataclass,
]:
    now = time.time_ns()
    for _ in range(N):
        c("foo")
    run_time = (time.time_ns() - now) / N
    print(f"{c.__name__:>25}: {run_time} ns")

Output:

             VanillaClass: 98.0 ns
         VanillaDataclass: 128.0 ns
            JaxtypedClass: 125.0 ns
        JaxtypedDataclass: 282535.0 ns
            BeartypeClass: 223.0 ns
        BeartypeDataclass: 243.0 ns
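
To put the reported numbers in proportion, the timings above can be turned into slowdown factors. This is only a small arithmetic sketch over the figures from the output; the names `timings_ns` and `slowdowns` are made up for illustration, and the factors will vary by machine and run.

```python
# Per-instance timings in nanoseconds, copied from the output above.
timings_ns = {
    "VanillaClass": 98.0,
    "VanillaDataclass": 128.0,
    "JaxtypedClass": 125.0,
    "JaxtypedDataclass": 282535.0,
    "BeartypeClass": 223.0,
    "BeartypeDataclass": 243.0,
}

# Slowdown of every variant relative to the plain dataclass.
baseline = timings_ns["VanillaDataclass"]
slowdowns = {name: t / baseline for name, t in timings_ns.items()}

for name, factor in slowdowns.items():
    print(f"{name:>25}: {factor:8.1f}x")

# Slowdown of the jaxtyped dataclass relative to the beartype dataclass.
vs_beartype = timings_ns["JaxtypedDataclass"] / timings_ns["BeartypeDataclass"]
print(f"jaxtyped vs. beartype dataclass: {vs_beartype:.1f}x")
```

With these figures the jaxtyped dataclass comes out at roughly 2200x the plain dataclass and roughly 1160x the beartype dataclass.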
@padix-key
Contributor Author

I do not know if this helps, but quick profiling revealed that most of the time is spent in the _check_dataclass_annotations() function.
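
For anyone who wants to reproduce this kind of quick profiling, a minimal sketch using only the standard library could look like the following. The helper `profile_construction` and the `Example` class are hypothetical names for illustration; to reproduce the finding, pass the `JaxtypedDataclass` from the benchmark instead of the plain dataclass used here.

```python
import cProfile
import dataclasses
import io
import pstats


def profile_construction(cls, *args, n=1000, top=10):
    """Return a cProfile report for constructing `cls(*args)` `n` times."""
    profiler = cProfile.Profile()
    profiler.enable()
    for _ in range(n):
        cls(*args)
    profiler.disable()

    # Render the `top` entries, sorted by cumulative time, into a string.
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative").print_stats(top)
    return stream.getvalue()


@dataclasses.dataclass
class Example:  # stand-in; swap in the class you want to inspect
    foo: str


report = profile_construction(Example, "foo")
print(report)
```

Sorting by cumulative time puts functions that dominate the call tree near the top of the report.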

@patrick-kidger
Owner

Hmm, is this not just the overhead from doing the actual type checking itself?

For what it's worth I don't think we currently respect @typing.no_type_check for dataclasses (only functions), but I'd be happy to take a PR updating that!
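
As a rough sketch of what respecting `@typing.no_type_check` on a dataclass might involve: the standard-library decorator sets a `__no_type_check__` attribute on the decorated class, so a check could look for that marker before running any annotation checks. The helper name `is_type_checking_disabled` is hypothetical, and this is not how jaxtyping currently behaves.

```python
import dataclasses
import typing


def is_type_checking_disabled(cls):
    """Return True if `cls` was marked with `@typing.no_type_check`.

    Hypothetical helper: `typing.no_type_check` sets `__no_type_check__ = True`
    on the object it decorates, so we only need to look for that attribute.
    """
    return getattr(cls, "__no_type_check__", False) is True


@typing.no_type_check
@dataclasses.dataclass
class Unchecked:
    foo: str


@dataclasses.dataclass
class Checked:
    foo: str


print(is_type_checking_disabled(Unchecked))
print(is_type_checking_disabled(Checked))
```

One design caveat: `getattr` also sees the marker on subclasses of a decorated class. If opting out should not be inherited, looking in `cls.__dict__` instead would restrict the check to the class itself.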

@padix-key
Contributor Author

padix-key commented Jul 22, 2024

I am not very familiar with the actual beartype and jaxtyping code, so take everything I say with a grain of salt, but I think it is suspicious that beartype is ~1000x faster than jaxtyped. My impression is that they achieve the same thing in this case, as no special jaxtyping array syntax is used.

@aldanor

aldanor commented Aug 24, 2024

Did a bit of line profiling on the above example, see below.

So while there's ~20% of time that could probably be saved (e.g. by caching results of the operations except the jaxtyped() call itself), it's mostly the type checking call, unfortunately.

Edit: hold on: isn't the actual type checking call just 3.6% of time?... (the last line)

In which case, can't we simply cache this f somewhere in the dataclass itself, or somewhere else?

Timer unit: 1e-06 s

Total time: 0.201161 s
Function: _check_dataclass_annotations at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def _check_dataclass_annotations(self, typechecker):
     2                                               """Creates and calls a function that checks the attributes of `self`
     3                                           
     4                                               `self` should be a dataclass instance. `typechecker` should be e.g.
     5                                               `beartype.beartype` or `typeguard.typechecked`.
     6                                               """
     7      1000       1748.0      1.7      0.9      parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
     8      1000        113.0      0.1      0.1      values = {}
     9      2000       1397.0      0.7      0.7      for field in dataclasses.fields(self):
    10      1000        124.0      0.1      0.1          annotation = field.type
    11      1000        161.0      0.2      0.1          if isinstance(annotation, str):
    12                                                       # Don't check stringified annotations. These are basically impossible to
    13                                                       # resolve correctly, so just skip them.
    14                                                       continue
    15      1000        742.0      0.7      0.4          if get_origin(annotation) is type:
    16                                                       args = get_args(annotation)
    17                                                       if len(args) == 1 and isinstance(args[0], str):
    18                                                           # We also special-case this one kind of partially-stringified type
    19                                                           # annotation, so as to support Equinox <v0.11.1.
    20                                                           # This was fixed in Equinox in
    21                                                           # https://github.com/patrick-kidger/equinox/pull/543
    22                                                           continue
    23      1000         75.0      0.1      0.0          try:
    24      1000        537.0      0.5      0.3              value = getattr(self, field.name)  # noqa: F841
    25      1000        117.0      0.1      0.1          except AttributeError:
    26      1000        109.0      0.1      0.1              continue  # allow uninitialised fields, which are allowed on dataclasses
    27                                           
    28                                                   parameters.append(
    29                                                       inspect.Parameter(
    30                                                           field.name,
    31                                                           inspect.Parameter.POSITIONAL_OR_KEYWORD,
    32                                                           annotation=field.type,
    33                                                       )
    34                                                   )
    35                                                   values[field.name] = value
    36                                           
    37      1000       1853.0      1.9      0.9      signature = inspect.Signature(parameters)
    38      2000      22255.0     11.1     11.1      f = _make_fn_with_signature(
    39      1000        152.0      0.2      0.1          self.__class__.__name__,
    40      1000        137.0      0.1      0.1          self.__class__.__qualname__,
    41      1000        137.0      0.1      0.1          self.__class__.__module__,
    42      1000         73.0      0.1      0.0          signature,
    43      1000         70.0      0.1      0.0          output=False,
    44                                               )
    45      1000     164024.0    164.0     81.5      f = jaxtyped(f, typechecker=typechecker)
    46      1000       7337.0      7.3      3.6      f(self, **values)

@patrick-kidger
Owner

Oh interesting! Thank you for profiling this -- I agree, caching sounds reasonable.

I'd be happy to take a PR on this!
