add padding of t in TargetConditionalFlowMatcher #68

Merged
merged 1 commit into from
Nov 10, 2023


francesco-vaselli
Contributor

Hello and thank you so much for the great package :)

What does this PR do?

Using the class TargetConditionalFlowMatcher results in the error:

/usr/local/lib/python3.10/dist-packages/torchcfm/conditional_flow_matching.py in sample_xt(self, x0, x1, t, epsilon)
    121         [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
    122         """
--> 123         mu_t = self.compute_mu_t(x0, x1, t)
    124         sigma_t = self.compute_sigma_t(t)
    125         sigma_t = pad_t_like_x(sigma_t, x0)

/usr/local/lib/python3.10/dist-packages/torchcfm/conditional_flow_matching.py in compute_mu_t(***failed resolving arguments***)
    327         """
    328         del x0
--> 329         return t * x1
    330 
    331     def compute_sigma_t(self, t):

RuntimeError: The size of tensor a (256) must match the size of tensor b (2) at non-singleton dimension 1

To fix this, I added a call to the existing function "pad_t_like_x()" in both the "compute_mu_t()" and "compute_conditional_flow()" methods of the class. This introduces no new dependency and resolves the error.
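For readers who want to reproduce the shape mismatch, here is a minimal sketch. It assumes a batch of 256 two-dimensional samples (the sizes in the error message) and uses a stand-in re-implementation of pad_t_like_x written for illustration; the helper shipped in torchcfm may differ in detail.

```python
import torch


def pad_t_like_x(t, x):
    """Stand-in for the torchcfm helper (an assumption, written for illustration).

    Reshapes a (batch,) time tensor to (batch, 1, ..., 1) so that it
    broadcasts against x along every non-batch dimension.
    """
    if isinstance(t, (float, int)):
        return t
    return t.reshape(-1, *([1] * (x.dim() - 1)))


batch, dim = 256, 2
x1 = torch.randn(batch, dim)  # target samples, shape (256, 2)
t = torch.rand(batch)         # one time value per sample, shape (256,)

# Without padding, broadcasting aligns trailing dimensions (256 vs 2) and fails.
try:
    t * x1
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True

# With padding, t has shape (256, 1) and scales each sample by its own time.
mu_t = pad_t_like_x(t, x1) * x1
```

The same padding applies wherever t multiplies or divides a batched tensor, which is why the fix touches both methods.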

Doubt about the performance of the modified class

I tested the modified TargetConditionalFlowMatcher on the 8gaussians to moons example. With the changes, the training code runs but the model does not converge, and performance is far worse than with the standard ConditionalFlowMatcher. I don't see how this PR could be the cause; is it a limitation of the TargetFM algorithm itself? Still, I find it a bit strange that it cannot converge on such a simple problem. My tests are collected here.

Let me know if I can provide anything else,
Best regards,
Francesco

@kilianFatras
Collaborator

kilianFatras commented Nov 10, 2023

Hello Francesco!

Thank you for your PR and for pointing out the bug. We also found it earlier this week, and I am working on a large PR at the moment which will fix it and also add tests for all functions. We should have written tests much earlier... It should be released early next week.

Yes, the problem comes from the TargetFM method! The reason is that it needs a Gaussian source distribution. Generalizing to other source distributions was the first motivation for creating the Independent-CFM method.
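As a sketch of why this is the case, assume the conditional path matches what the traceback above suggests (compute_mu_t discards x0 and returns t * x1) together with a linearly shrinking standard deviation. The symbols below are illustrative notation, not quoted from the package:

$$
p_t(x \mid x_1) = \mathcal{N}\big(x \mid t\,x_1,\ (1 - (1 - \sigma)\,t)^2 I\big)
$$

$$
p_0(x \mid x_1) = \mathcal{N}(x \mid 0,\ I), \qquad p_1(x \mid x_1) = \mathcal{N}(x \mid x_1,\ \sigma^2 I)
$$

At t = 0 the path is a standard normal regardless of x0, so the learned vector field transports N(0, I) to the target. Integrating it from 8gaussians samples starts from a distribution the model was never trained on, which is consistent with the poor results reported above.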

@kilianFatras
Collaborator

Hello again Francesco,

I have checked with my PRs and yours is good to go actually! I am sorry for the confusion. I thought I added more elements than the fix but it is fine... I will merge your PR after reviewing it. It should be done today. Thank you for your contribution to TorchCFM!

@francesco-vaselli
Contributor Author

Great! Thank you so much for the well-maintained package and for clarifying my doubts.

@kilianFatras
Collaborator

Thank you very much for your contribution! Welcome to TorchCFM contributors 🥇 ! I am merging the PR now! I wish you a wonderful day.

@kilianFatras kilianFatras merged commit 21cd0c8 into atong01:main Nov 10, 2023
16 checks passed
2 participants