
Splicing / variadic symbolic expressions #265

Open
martenlienen opened this issue Nov 12, 2024 · 1 comment
Labels
feature New feature

Comments

@martenlienen

Would it be possible to make the following code snippet work?

import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

class A:
    def __init__(self, shape: tuple[int, ...]):
        self.shape = shape

    @jaxtyped(typechecker=beartype)
    def forward(self, x: Float[Tensor, "... {self.shape}"]) -> Float[Tensor, "..."]:
        return x.flatten(start_dim=-len(self.shape)).sum(dim=-1)

a = A((3, 4, 5))  # trailing dims of x below are (3, 4, 5)
x = torch.randn((7, 3, 4, 5))
print(a.forward(x))

At the moment this does not work, as far as I can tell, because {self.shape} is matched against only a single dimension of x. Is there a way to evaluate the expression and splice the tuple's elements into the type before it gets matched against the dimensions? Maybe with something like a *{self.shape} syntax?
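To make the request concrete, here is a hand-rolled sketch of the check that such a splice would perform at runtime. The `matches` helper is hypothetical and not part of jaxtyping; it only illustrates the intended semantics (the trailing dimensions of x must equal self.shape element by element):

```python
# Hypothetical sketch: the check a *{self.shape} splice would imply.
def matches(x_shape: tuple, trailing: tuple) -> bool:
    # x must have at least len(trailing) dims, and its trailing dims
    # must equal the spliced tuple exactly.
    n = len(trailing)
    return len(x_shape) >= n and tuple(x_shape[-n:]) == tuple(trailing)

# A tensor of shape (7, 3, 4, 5) should satisfy "... {self.shape}"
# when self.shape == (3, 4, 5), but not when it is (3, 10, 5).
print(matches((7, 3, 4, 5), (3, 4, 5)))   # True
print(matches((7, 3, 4, 5), (3, 10, 5)))  # False
```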

@patrick-kidger
Owner

Yup, this is a known issue. I don't have a nice way to fix this right now -- this is quite a complicated corner of jaxtyping! -- but I'd be happy to take a PR if someone feels like taking this on.

If need be you can maybe do something like str(self.shape).replace(",", " ")[1:-1] but that's obviously pretty messy.
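Concretely, that string hack turns the tuple into space-separated fixed dimensions, which can then be interpolated into the annotation one token per dimension (a sketch only; it leaves doubled spaces behind, which jaxtyping's whitespace splitting appears to tolerate, though that is an assumption):

```python
shape = (3, 4, 5)
# str(shape) == "(3, 4, 5)": swap commas for spaces, strip the parens.
# The result can be spliced into an annotation like f"... {spliced}".
spliced = str(shape).replace(",", " ")[1:-1]
print(repr(spliced))  # '3  4  5' (note the doubled spaces)
```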

@patrick-kidger patrick-kidger added the feature New feature label Nov 12, 2024