Commit

Merge branch 'main' into forest_flow
kilianFatras authored Nov 23, 2023
2 parents 3c7165f + 81fcb8d commit d3fe6f2
Showing 20 changed files with 392 additions and 30 deletions.
16 changes: 9 additions & 7 deletions .github/workflows/test.yaml
@@ -1,4 +1,4 @@
name: Tests
name: TorchCFM Tests

on:
  push:
@@ -14,7 +14,7 @@ jobs:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest]
        python-version: ["3.8", "3.9", "3.10"]
        python-version: ["3.8", "3.9", "3.10", "3.11"]

    steps:
      - name: Checkout
@@ -28,7 +28,6 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r runner-requirements.txt
          pip install pytest
          pip install sh
          pip install -e .
@@ -39,10 +38,10 @@
      - name: Run pytest
        run: |
          pytest -v runner
          pytest -v --ignore=examples --ignore=runner
  # upload code coverage report
  code-coverage:
  code-coverage-torchcfm:
    runs-on: ubuntu-latest

    steps:
@@ -57,14 +56,17 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r runner-requirements.txt
          pip install pytest
          pip install pytest-cov[toml]
          pip install sh
          pip install -e .
      - name: Run tests and collect coverage
        run: pytest runner --cov runner # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
        run: pytest . --cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          name: codecov-torchcfm
          verbose: true
          fail_ci_if_error: true
74 changes: 74 additions & 0 deletions .github/workflows/test_runner.yaml
@@ -0,0 +1,74 @@
name: Runner Tests

on:
  push:
    branches: [main]
  pull_request:
    branches: [main, "release/*"]

jobs:
  run_tests_ubuntu:
    runs-on: ${{ matrix.os }}

    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest]
        python-version: ["3.8", "3.9", "3.10"]

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r runner-requirements.txt
          pip install pytest
          pip install sh
          pip install -e .
      - name: List dependencies
        run: |
          python -m pip list
      - name: Run pytest
        run: |
          pytest -v runner
  # upload code coverage report
  code-coverage-runner:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Set up Python 3.10
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r runner-requirements.txt
          pip install pytest
          pip install pytest-cov[toml]
          pip install sh
          pip install -e .
      - name: Run tests and collect coverage
        run: pytest runner --cov runner # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          name: codecov-runner
          verbose: true
          fail_ci_if_error: true
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -35,3 +35,6 @@ count = true

[tool.bandit]
skips = ["B101", "B311"]

[tool.isort]
known_first_party = ["tests", "src"]
3 changes: 2 additions & 1 deletion runner/src/datamodules/distribution_datamodule.py
@@ -7,10 +7,11 @@
from pytorch_lightning import LightningDataModule
from pytorch_lightning.trainer.supporters import CombinedLoader
from sklearn.preprocessing import StandardScaler
from src import utils
from torch.utils.data import DataLoader, Sampler, TensorDataset, random_split
from torchdyn.datasets import ToyDataset

from src import utils

from .components.base import BaseLightningDataModule
from .components.time_dataset import load_dataset
from .components.tnet_dataset import SCData
1 change: 1 addition & 0 deletions runner/src/eval.py
@@ -38,6 +38,7 @@
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers import LightningLoggerBase

from src import utils

log = utils.get_pylogger(__name__)
1 change: 1 addition & 0 deletions runner/src/train.py
@@ -39,6 +39,7 @@
from omegaconf import DictConfig
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers import LightningLoggerBase

from src import utils

log = utils.get_pylogger(__name__)
1 change: 1 addition & 0 deletions runner/src/utils/rich_utils.py
@@ -8,6 +8,7 @@
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning.utilities import rank_zero_only
from rich.prompt import Prompt

from src.utils import pylogger

log = pylogger.get_pylogger(__name__)
1 change: 1 addition & 0 deletions runner/src/utils/utils.py
@@ -9,6 +9,7 @@
from pytorch_lightning import Callback
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only

from src.utils import pylogger, rich_utils

log = pylogger.get_pylogger(__name__)
1 change: 1 addition & 0 deletions runner/tests/helpers/run_if.py
@@ -10,6 +10,7 @@
import torch
from packaging.version import Version
from pkg_resources import get_distribution

from tests.helpers.package_available import (
    _COMET_AVAILABLE,
    _DEEPSPEED_AVAILABLE,
1 change: 1 addition & 0 deletions runner/tests/helpers/run_sh_command.py
@@ -1,6 +1,7 @@
from typing import List

import pytest

from tests.helpers.package_available import _SH_AVAILABLE

if _SH_AVAILABLE:
1 change: 1 addition & 0 deletions runner/tests/test_datamodule.py
@@ -2,6 +2,7 @@

import pytest
import torch

from src.datamodules.distribution_datamodule import (
    SKLearnDataModule,
    TorchDynDataModule,
1 change: 1 addition & 0 deletions runner/tests/test_eval.py
@@ -3,6 +3,7 @@
import pytest
from hydra.core.hydra_config import HydraConfig
from omegaconf import open_dict

from src.eval import evaluate
from src.train import train

1 change: 1 addition & 0 deletions runner/tests/test_sweeps.py
@@ -1,4 +1,5 @@
import pytest

from tests.helpers.run_if import RunIf
from tests.helpers.run_sh_command import run_sh_command

1 change: 1 addition & 0 deletions runner/tests/test_train.py
@@ -3,6 +3,7 @@
import pytest
from hydra.core.hydra_config import HydraConfig
from omegaconf import open_dict

from src.train import train
from tests.helpers.run_if import RunIf

4 changes: 2 additions & 2 deletions setup.py
@@ -35,6 +35,6 @@
    license="MIT",
    long_description=readme,
    long_description_content_type="text/markdown",
    packages=find_packages(),
    packages=find_packages(exclude=["tests", "tests.*"]),
    extras_require={"forest-flow": ["xgboost", "scikit-learn", "ForestDiffusion"]},
)
)
126 changes: 126 additions & 0 deletions tests/test_conditional_flow_matcher.py
@@ -0,0 +1,126 @@
"""Tests for Conditional Flow Matcher classers."""

# Author: Kilian Fatras <[email protected]>

import math

import numpy as np
import pytest
import torch

from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    SchrodingerBridgeConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
    pad_t_like_x,
)
from torchcfm.optimal_transport import OTPlanSampler

TEST_SEED = 1994
TEST_BATCH_SIZE = 128
SIGMA_CONDITION = {
    "sb_cfm": lambda x: x <= 0,
}


def random_samples(shape, batch_size=TEST_BATCH_SIZE):
    """Generate random samples of different dimensions."""
    if isinstance(shape, int):
        shape = [shape]
    return [torch.randn(batch_size, *shape), torch.randn(batch_size, *shape)]


def compute_xt_ut(method, x0, x1, t_given, sigma, epsilon):
    if method == "vp_cfm":
        sigma_t = sigma
        mu_t = torch.cos(math.pi / 2 * t_given) * x0 + torch.sin(math.pi / 2 * t_given) * x1
        computed_xt = mu_t + sigma_t * epsilon
        computed_ut = (
            math.pi
            / 2
            * (torch.cos(math.pi / 2 * t_given) * x1 - torch.sin(math.pi / 2 * t_given) * x0)
        )
    elif method == "t_cfm":
        sigma_t = 1 - (1 - sigma) * t_given
        mu_t = t_given * x1
        computed_xt = mu_t + sigma_t * epsilon
        computed_ut = (x1 - (1 - sigma) * computed_xt) / sigma_t

    elif method == "sb_cfm":
        sigma_t = sigma * torch.sqrt(t_given * (1 - t_given))
        mu_t = t_given * x1 + (1 - t_given) * x0
        computed_xt = mu_t + sigma_t * epsilon
        computed_ut = (
            (1 - 2 * t_given)
            / (2 * t_given * (1 - t_given) + 1e-8)
            * (computed_xt - (t_given * x1 + (1 - t_given) * x0))
            + x1
            - x0
        )
    elif method in ["exact_ot_cfm", "i_cfm"]:
        sigma_t = sigma
        mu_t = t_given * x1 + (1 - t_given) * x0
        computed_xt = mu_t + sigma_t * epsilon
        computed_ut = x1 - x0

    return computed_xt, computed_ut


def get_flow_matcher(method, sigma):
    if method == "vp_cfm":
        fm = VariancePreservingConditionalFlowMatcher(sigma=sigma)
    elif method == "t_cfm":
        fm = TargetConditionalFlowMatcher(sigma=sigma)
    elif method == "sb_cfm":
        fm = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma, ot_method="sinkhorn")
    elif method == "exact_ot_cfm":
        fm = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
    elif method == "i_cfm":
        fm = ConditionalFlowMatcher(sigma=sigma)
    return fm


def sample_plan(method, x0, x1, sigma):
    if method == "sb_cfm":
        x0, x1 = OTPlanSampler(method="sinkhorn", reg=2 * (sigma**2)).sample_plan(x0, x1)
    elif method == "exact_ot_cfm":
        x0, x1 = OTPlanSampler(method="exact").sample_plan(x0, x1)
    return x0, x1


@pytest.mark.parametrize("method", ["vp_cfm", "t_cfm", "sb_cfm", "exact_ot_cfm", "i_cfm"])
# Test both integer and floating sigma
@pytest.mark.parametrize("sigma", [0.0, 0.5, 1.5, 0, 1])
@pytest.mark.parametrize("shape", [[1], [2], [1, 2], [3, 4, 5]])
def test_fm(method, sigma, shape):
    batch_size = TEST_BATCH_SIZE

    if method in SIGMA_CONDITION.keys() and SIGMA_CONDITION[method](sigma):
        with pytest.raises(ValueError):
            get_flow_matcher(method, sigma)
        return

    FM = get_flow_matcher(method, sigma)
    x0, x1 = random_samples(shape, batch_size=batch_size)
    torch.manual_seed(TEST_SEED)
    np.random.seed(TEST_SEED)
    t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True)

    if method in ["sb_cfm", "exact_ot_cfm"]:
        torch.manual_seed(TEST_SEED)
        np.random.seed(TEST_SEED)
        x0, x1 = sample_plan(method, x0, x1, sigma)

    torch.manual_seed(TEST_SEED)
    t_given_init = torch.rand(batch_size)
    t_given = t_given_init.reshape(-1, *([1] * (x0.dim() - 1)))
    sigma_pad = pad_t_like_x(sigma, x0)
    epsilon = torch.randn_like(x0)
    computed_xt, computed_ut = compute_xt_ut(method, x0, x1, t_given, sigma_pad, epsilon)

    assert torch.all(ut.eq(computed_ut))
    assert torch.all(xt.eq(computed_xt))
    assert torch.all(eps.eq(epsilon))
    assert any(t_given_init == t)
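
For readers skimming the new tests, a minimal sketch of how the API they exercise is typically consumed in a training loop. The ConditionalFlowMatcher(sigma=...) constructor and the sample_location_and_conditional_flow(x0, x1, return_noise=True) call are taken from the test above; the tiny stand-in network, the 2-D Gaussian batches, and the optimizer settings are illustrative assumptions, not part of this commit.

import torch

from torchcfm.conditional_flow_matching import ConditionalFlowMatcher

# Stand-in vector-field network (not torchcfm's MLP); it reads x concatenated with t.
net = torch.nn.Sequential(
    torch.nn.Linear(2 + 1, 64),
    torch.nn.SELU(),
    torch.nn.Linear(64, 2),
)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
fm = ConditionalFlowMatcher(sigma=0.1)

for _ in range(10):  # a few illustrative steps
    x0, x1 = torch.randn(128, 2), torch.randn(128, 2)  # source / target batches (toy data)
    # t ~ U(0, 1), xt is a point on the conditional path, ut is the conditional target field
    t, xt, ut, eps = fm.sample_location_and_conditional_flow(x0, x1, return_noise=True)
    vt = net(torch.cat([xt, t[:, None]], dim=-1))
    loss = torch.mean((vt - ut) ** 2)  # regress the predicted field onto ut
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()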
15 changes: 15 additions & 0 deletions tests/test_models.py
@@ -0,0 +1,15 @@
import pytest

from torchcfm.models import MLP
from torchcfm.models.unet import UNetModel


def test_initialize_models():
    model = UNetModel(
        dim=(1, 28, 28),
        num_channels=32,
        num_res_blocks=1,
        num_classes=10,
        class_cond=True,
    )
    model = MLP(dim=2, time_varying=True, w=64)