Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1.0.6 release #87

Merged
merged 16 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
566 changes: 566 additions & 0 deletions examples/2D_tutorials/Flow_matching_tutorial.ipynb

Large diffs are not rendered by default.

228 changes: 228 additions & 0 deletions examples/2D_tutorials/SF2M_tutorial.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an __FID of 3.5__ on the Cifar10 dataset.

<p align="center">
<img src="../../assets/169_generated_samples_otcfm.png" width="600"/>
<img src="../../../assets/169_generated_samples_otcfm.png" width="600"/>
</p>

To reproduce the experiments and save the weights, install the requirements from the main repository and then run (runs on a single RTX 2080 GPU):
Expand Down
File renamed without changes.
File renamed without changes.
470 changes: 470 additions & 0 deletions examples/images/conditional_mnist.ipynb

Large diffs are not rendered by default.

362 changes: 362 additions & 0 deletions examples/images/mnist_example.ipynb

Large diffs are not rendered by default.

201 changes: 0 additions & 201 deletions examples/notebooks/SF2M_2D_example.ipynb

This file was deleted.

467 changes: 0 additions & 467 deletions examples/notebooks/conditional_mnist.ipynb

This file was deleted.

363 changes: 0 additions & 363 deletions examples/notebooks/mnist_example.ipynb

This file was deleted.

831 changes: 0 additions & 831 deletions examples/notebooks/single-cell_example.ipynb

This file was deleted.

856 changes: 856 additions & 0 deletions examples/single_cell/single-cell_example.ipynb

Large diffs are not rendered by default.

28 changes: 0 additions & 28 deletions torchcfm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,31 +64,3 @@ def plot_trajectories(traj):
plt.xticks([])
plt.yticks([])
plt.show()


class SDE(torch.nn.Module):
noise_type = "diagonal"
sde_type = "ito"

def __init__(self, ode_drift, score, noise=1.0, reverse=False):
super().__init__()
self.drift = ode_drift
self.score = score
self.reverse = reverse
self.noise = noise

# Drift
def f(self, t, y):
if self.reverse:
t = 1 - t
if len(t.shape) == len(y.shape):
x = torch.cat([y, t], 1)
else:
x = torch.cat([y, t.repeat(y.shape[0])[:, None]], 1)
if self.reverse:
return -self.drift(x) + self.score(x)
return self.drift(x) + self.score(x)

# Diffusion
def g(self, t, y):
return torch.ones_like(t) * torch.ones_like(y) * self.noise
2 changes: 1 addition & 1 deletion torchcfm/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.5"
__version__ = "1.0.6"
Loading