-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into forest_flow
- Loading branch information
Showing
20 changed files
with
392 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
name: Runner Tests | ||
|
||
on: | ||
push: | ||
branches: [main] | ||
pull_request: | ||
branches: [main, "release/*"] | ||
|
||
jobs: | ||
run_tests_ubuntu: | ||
runs-on: ${{ matrix.os }} | ||
|
||
strategy: | ||
fail-fast: false | ||
matrix: | ||
os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest] | ||
python-version: ["3.8", "3.9", "3.10"] | ||
|
||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@v3 | ||
|
||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install -r runner-requirements.txt | ||
pip install pytest | ||
pip install sh | ||
pip install -e . | ||
- name: List dependencies | ||
run: | | ||
python -m pip list | ||
- name: Run pytest | ||
run: | | ||
pytest -v runner | ||
# upload code coverage report | ||
code-coverage-runner: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@v3 | ||
|
||
- name: Set up Python 3.10 | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: "3.10" | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install -r runner-requirements.txt | ||
pip install pytest | ||
pip install pytest-cov[toml] | ||
pip install sh | ||
pip install -e . | ||
- name: Run tests and collect coverage | ||
run: pytest runner --cov runner # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER | ||
|
||
- name: Upload coverage to Codecov | ||
uses: codecov/codecov-action@v3 | ||
with: | ||
name: codecov-runner | ||
verbose: true | ||
fail_ci_if_error: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,3 +35,6 @@ count = true | |
|
||
[tool.bandit] | ||
skips = ["B101", "B311"] | ||
|
||
[tool.isort] | ||
known_first_party = ["tests", "src"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import pytest | ||
|
||
from tests.helpers.run_if import RunIf | ||
from tests.helpers.run_sh_command import run_sh_command | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
"""Tests for Conditional Flow Matcher classers.""" | ||
|
||
# Author: Kilian Fatras <[email protected]> | ||
|
||
import math | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
|
||
from torchcfm.conditional_flow_matching import ( | ||
ConditionalFlowMatcher, | ||
ExactOptimalTransportConditionalFlowMatcher, | ||
SchrodingerBridgeConditionalFlowMatcher, | ||
TargetConditionalFlowMatcher, | ||
VariancePreservingConditionalFlowMatcher, | ||
pad_t_like_x, | ||
) | ||
from torchcfm.optimal_transport import OTPlanSampler | ||
|
||
TEST_SEED = 1994 | ||
TEST_BATCH_SIZE = 128 | ||
SIGMA_CONDITION = { | ||
"sb_cfm": lambda x: x <= 0, | ||
} | ||
|
||
|
||
def random_samples(shape, batch_size=TEST_BATCH_SIZE): | ||
"""Generate random samples of different dimensions.""" | ||
if isinstance(shape, int): | ||
shape = [shape] | ||
return [torch.randn(batch_size, *shape), torch.randn(batch_size, *shape)] | ||
|
||
|
||
def compute_xt_ut(method, x0, x1, t_given, sigma, epsilon): | ||
if method == "vp_cfm": | ||
sigma_t = sigma | ||
mu_t = torch.cos(math.pi / 2 * t_given) * x0 + torch.sin(math.pi / 2 * t_given) * x1 | ||
computed_xt = mu_t + sigma_t * epsilon | ||
computed_ut = ( | ||
math.pi | ||
/ 2 | ||
* (torch.cos(math.pi / 2 * t_given) * x1 - torch.sin(math.pi / 2 * t_given) * x0) | ||
) | ||
elif method == "t_cfm": | ||
sigma_t = 1 - (1 - sigma) * t_given | ||
mu_t = t_given * x1 | ||
computed_xt = mu_t + sigma_t * epsilon | ||
computed_ut = (x1 - (1 - sigma) * computed_xt) / sigma_t | ||
|
||
elif method == "sb_cfm": | ||
sigma_t = sigma * torch.sqrt(t_given * (1 - t_given)) | ||
mu_t = t_given * x1 + (1 - t_given) * x0 | ||
computed_xt = mu_t + sigma_t * epsilon | ||
computed_ut = ( | ||
(1 - 2 * t_given) | ||
/ (2 * t_given * (1 - t_given) + 1e-8) | ||
* (computed_xt - (t_given * x1 + (1 - t_given) * x0)) | ||
+ x1 | ||
- x0 | ||
) | ||
elif method in ["exact_ot_cfm", "i_cfm"]: | ||
sigma_t = sigma | ||
mu_t = t_given * x1 + (1 - t_given) * x0 | ||
computed_xt = mu_t + sigma_t * epsilon | ||
computed_ut = x1 - x0 | ||
|
||
return computed_xt, computed_ut | ||
|
||
|
||
def get_flow_matcher(method, sigma): | ||
if method == "vp_cfm": | ||
fm = VariancePreservingConditionalFlowMatcher(sigma=sigma) | ||
elif method == "t_cfm": | ||
fm = TargetConditionalFlowMatcher(sigma=sigma) | ||
elif method == "sb_cfm": | ||
fm = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma, ot_method="sinkhorn") | ||
elif method == "exact_ot_cfm": | ||
fm = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma) | ||
elif method == "i_cfm": | ||
fm = ConditionalFlowMatcher(sigma=sigma) | ||
return fm | ||
|
||
|
||
def sample_plan(method, x0, x1, sigma): | ||
if method == "sb_cfm": | ||
x0, x1 = OTPlanSampler(method="sinkhorn", reg=2 * (sigma**2)).sample_plan(x0, x1) | ||
elif method == "exact_ot_cfm": | ||
x0, x1 = OTPlanSampler(method="exact").sample_plan(x0, x1) | ||
return x0, x1 | ||
|
||
|
||
@pytest.mark.parametrize("method", ["vp_cfm", "t_cfm", "sb_cfm", "exact_ot_cfm", "i_cfm"]) | ||
# Test both integer and floating sigma | ||
@pytest.mark.parametrize("sigma", [0.0, 0.5, 1.5, 0, 1]) | ||
@pytest.mark.parametrize("shape", [[1], [2], [1, 2], [3, 4, 5]]) | ||
def test_fm(method, sigma, shape): | ||
batch_size = TEST_BATCH_SIZE | ||
|
||
if method in SIGMA_CONDITION.keys() and SIGMA_CONDITION[method](sigma): | ||
with pytest.raises(ValueError): | ||
get_flow_matcher(method, sigma) | ||
return | ||
|
||
FM = get_flow_matcher(method, sigma) | ||
x0, x1 = random_samples(shape, batch_size=batch_size) | ||
torch.manual_seed(TEST_SEED) | ||
np.random.seed(TEST_SEED) | ||
t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True) | ||
|
||
if method in ["sb_cfm", "exact_ot_cfm"]: | ||
torch.manual_seed(TEST_SEED) | ||
np.random.seed(TEST_SEED) | ||
x0, x1 = sample_plan(method, x0, x1, sigma) | ||
|
||
torch.manual_seed(TEST_SEED) | ||
t_given_init = torch.rand(batch_size) | ||
t_given = t_given_init.reshape(-1, *([1] * (x0.dim() - 1))) | ||
sigma_pad = pad_t_like_x(sigma, x0) | ||
epsilon = torch.randn_like(x0) | ||
computed_xt, computed_ut = compute_xt_ut(method, x0, x1, t_given, sigma_pad, epsilon) | ||
|
||
assert torch.all(ut.eq(computed_ut)) | ||
assert torch.all(xt.eq(computed_xt)) | ||
assert torch.all(eps.eq(epsilon)) | ||
assert any(t_given_init == t) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import pytest | ||
|
||
from torchcfm.models import MLP | ||
from torchcfm.models.unet import UNetModel | ||
|
||
|
||
def test_initialize_models(): | ||
model = UNetModel( | ||
dim=(1, 28, 28), | ||
num_channels=32, | ||
num_res_blocks=1, | ||
num_classes=10, | ||
class_cond=True, | ||
) | ||
model = MLP(dim=2, time_varying=True, w=64) |
Oops, something went wrong.