diff --git a/README.md b/README.md
index af8d1f1..2bc07b6 100644
--- a/README.md
+++ b/README.md
@@ -10,11 +10,12 @@ TODO: FIX PREPRINT BUTTON AFTER WE'RE ON ARXIV!!!
 
 ## Description
 
-This is the official repository for the paper [Iterated Denoising Energy Matching for Sampling from Boltzmann Densities](https://arxiv.org/abs/2310.02391) (TODO: FIX THIS LINK AFTER WE'RE ON ARXIV).
+
+This is the official repository for the paper [Iterated Denoising Energy Matching for Sampling from Boltzmann Densities](https://arxiv.org/abs/2310.02391) (TODO: FIX THIS LINK AFTER WE'RE ON ARXIV).
 We propose iDEM, a scalable and efficient method to sample from unnormalized probability distributions. iDEM makes use of the
 DEM objective, inspired by the stochastic regression and simulation free principles of score and flow matching objectives while
 allowing one to learn off-policy, in a loop while itself generating (optionally exploratory) new states which are subsequently
-learned on. iDEM is also capable of incorporating symmetries, namely those represented by the product group of $SE(3) \times \mathbb{S}_n$. We experiment on a 2D GMM task as well as a number of physics
+learned on. iDEM is also capable of incorporating symmetries, namely those represented by the product group of $SE(3) \\times \\mathbb{S}\_n$. We experiment on a 2D GMM task as well as a number of physics
 inspired problems. These include:
 
 - DW4 -- the 4 particle double well potential (8 dimensions total)
@@ -29,13 +30,14 @@ out most of the code and experiments with help from [@sarthmit](https://github.c
 For installation we recommend the use of Micromamba. Please refer [here](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html)
 for an installation guide for Micromamba.
 First, we install dependencies
+
 ```bash
 # clone project
 git clone git@github.com:jarridrb/DEM.git
 cd DEM
 
 # create micromamba environment
-micromamba create -f environment.yaml 
+micromamba create -f environment.yaml
 micromamba activate dem
 
 # install requirements
@@ -48,6 +50,7 @@ an example `.env.example` file for convenience. Note that to use wandb we requir
 `.env` file.
 
 To run an experiment, e.g., GMM with iDEM, you can run on the command line
+
 ```bash
 python dem/train.py experiment=gmm_idem
 ```
@@ -55,18 +58,21 @@ python dem/train.py experiment=gmm_idem
 We include configs for all experiments matching the settings we used in our paper for both iDEM and pDEM with the
 exception of LJ55 for which we only include a config for iDEM and pDEM had convergence issues on LJ55.
 
-## Current Code 
+## Current Code
+
 The current repository contains code for experiments for iDEM and pDEM as specified in our paper.
 
 ## Planned Updates
+
 - [ ] Code to do Langevin on top of generated samples
 
 ## Citations
+
 If this codebase is useful towards other research efforts please consider citing us.
 TODO: FIX THIS CITATION ONCE WE'RE ON ARXIV!!!
 ```
 @misc{bose2023se3stochastic,
-      title={SE(3)-Stochastic Flow Matching for Protein Backbone Generation}, 
+      title={SE(3)-Stochastic Flow Matching for Protein Backbone Generation},
       author={Avishek Joey Bose and Tara Akhound-Sadegh and Kilian Fatras and Guillaume Huguet and Jarrid Rector-Brooks and Cheng-Hao Liu and Andrei Cristian Nica and Maksym Korablyov and Michael Bronstein and Alexander Tong},
       year={2023},
       eprint={2310.02391},
@@ -75,14 +81,13 @@ If this codebase is useful towards other research efforts please consider citing
 }
 ```
 
-
 ## Contribute
 
 We welcome issues and pull requests (especially bug fixes) and contributions.
We will try our best to improve readability and answer questions! - ## Licences + This repo is licensed under the [MIT License](https://opensource.org/license/mit/). ### Warning: the current code uses PyTorch 2.0.0+ diff --git a/configs/energy/dw4.yaml b/configs/energy/dw4.yaml index 36f38b5..ff45200 100644 --- a/configs/energy/dw4.yaml +++ b/configs/energy/dw4.yaml @@ -1,19 +1,17 @@ - _target_: dem.energies.multi_double_well_energy.MultiDoubleWellEnergy +_target_: dem.energies.multi_double_well_energy.MultiDoubleWellEnergy - dimensionality: 8 - n_particles: 4 +dimensionality: 8 +n_particles: 4 - data_from_efm: true - data_path: "data/test_split_DW4.npy" - data_path_train: "data/train_split_DW4.npy" - data_path_val: "data/val_split_DW4.npy" +data_from_efm: true +data_path: "data/test_split_DW4.npy" +data_path_train: "data/train_split_DW4.npy" +data_path_val: "data/val_split_DW4.npy" - device: ${trainer.accelerator} +device: ${trainer.accelerator} - plot_samples_epoch_period: 1 +plot_samples_epoch_period: 1 - data_normalization_factor: 1.0 +data_normalization_factor: 1.0 - is_molecule: true - - +is_molecule: true diff --git a/configs/energy/gmm.yaml b/configs/energy/gmm.yaml index aab02ed..ed9962b 100644 --- a/configs/energy/gmm.yaml +++ b/configs/energy/gmm.yaml @@ -1,13 +1,13 @@ - _target_: dem.energies.gmm_energy.GMM +_target_: dem.energies.gmm_energy.GMM - dimensionality: 2 - n_mixes: 40 - loc_scaling: 40 - log_var_scaling: 1.0 +dimensionality: 2 +n_mixes: 40 +loc_scaling: 40 +log_var_scaling: 1.0 - device: ${trainer.accelerator} +device: ${trainer.accelerator} - plot_samples_epoch_period: 1 +plot_samples_epoch_period: 1 - should_unnormalize: true - data_normalization_factor: 50 +should_unnormalize: true +data_normalization_factor: 50 diff --git a/configs/energy/lj13.yaml b/configs/energy/lj13.yaml index 88ea3f6..f0caad7 100644 --- a/configs/energy/lj13.yaml +++ b/configs/energy/lj13.yaml @@ -1,15 +1,15 @@ - _target_: dem.energies.lennardjones_energy.LennardJonesEnergy +_target_: dem.energies.lennardjones_energy.LennardJonesEnergy - dimensionality: 39 - n_particles: 13 - data_path: "data/test_split_LJ13-1000.npy" - data_path_train: "data/train_split_LJ13-1000.npy" - data_path_val: "data/test_split_LJ13-1000.npy" +dimensionality: 39 +n_particles: 13 +data_path: "data/test_split_LJ13-1000.npy" +data_path_train: "data/train_split_LJ13-1000.npy" +data_path_val: "data/test_split_LJ13-1000.npy" - device: ${trainer.accelerator} +device: ${trainer.accelerator} - plot_samples_epoch_period: 1 +plot_samples_epoch_period: 1 - data_normalization_factor: 1.0 +data_normalization_factor: 1.0 - is_molecule: True +is_molecule: True diff --git a/configs/energy/lj55.yaml b/configs/energy/lj55.yaml index 76b409d..1ef832a 100644 --- a/configs/energy/lj55.yaml +++ b/configs/energy/lj55.yaml @@ -1,18 +1,15 @@ - _target_: dem.energies.lennardjones_energy.LennardJonesEnergy +_target_: dem.energies.lennardjones_energy.LennardJonesEnergy - dimensionality: 165 - n_particles: 55 - data_path: "data/test_split_LJ55-1000-part1.npy" - data_path_train: "data/train_split_LJ55-1000-part1.npy" - data_path_val: "data/val_split_LJ55-1000-part1.npy" +dimensionality: 165 +n_particles: 55 +data_path: "data/test_split_LJ55-1000-part1.npy" +data_path_train: "data/train_split_LJ55-1000-part1.npy" +data_path_val: "data/val_split_LJ55-1000-part1.npy" - device: ${trainer.accelerator} +device: ${trainer.accelerator} - plot_samples_epoch_period: 1 +plot_samples_epoch_period: 1 - data_normalization_factor: 1.0 - - is_molecule: True - - 
+data_normalization_factor: 1.0 +is_molecule: True diff --git a/configs/experiment/dw4_idem.yaml b/configs/experiment/dw4_idem.yaml index 8dad845..a8daf21 100644 --- a/configs/experiment/dw4_idem.yaml +++ b/configs/experiment/dw4_idem.yaml @@ -15,7 +15,7 @@ logger: tags: ${tags} group: "dw4_efm" -defaults: +defaults: - override /energy: dw4 - override /model/net: egnn @@ -31,8 +31,8 @@ model: noise_schedule: _target_: dem.models.components.noise_schedules.GeometricNoiseSchedule - sigma_min: 0.00001 - sigma_max: 3 + sigma_min: 0.00001 + sigma_max: 3 partial_prior: _target_: dem.energies.base_prior.MeanFreePrior @@ -48,13 +48,13 @@ model: _target_: dem.models.components.clipper.Clipper should_clip_scores: True should_clip_log_rewards: False - max_score_norm: 20 + max_score_norm: 20 min_log_reward: null # num_samples_to_sample_from_buffer: 5120 diffusion_scale: 1 num_samples_to_generate_per_epoch: 1000 - + init_from_prior: true eval_batch_size: 1000 diff --git a/configs/experiment/dw4_pdem.yaml b/configs/experiment/dw4_pdem.yaml index 3f13503..d1bf040 100644 --- a/configs/experiment/dw4_pdem.yaml +++ b/configs/experiment/dw4_pdem.yaml @@ -15,7 +15,7 @@ logger: tags: ${tags} group: "dw4_efm" -defaults: +defaults: - override /energy: dw4 - override /model/net: egnn @@ -31,8 +31,8 @@ model: noise_schedule: _target_: dem.models.components.noise_schedules.GeometricNoiseSchedule - sigma_min: 0.00001 - sigma_max: 3 + sigma_min: 0.00001 + sigma_max: 3 partial_prior: _target_: dem.energies.base_prior.MeanFreePrior @@ -48,13 +48,13 @@ model: _target_: dem.models.components.clipper.Clipper should_clip_scores: True should_clip_log_rewards: False - max_score_norm: 20 + max_score_norm: 20 min_log_reward: null # num_samples_to_sample_from_buffer: 5120 diffusion_scale: 1 num_samples_to_generate_per_epoch: 1000 - + init_from_prior: true eval_batch_size: 1000 diff --git a/configs/experiment/gmm_idem.yaml b/configs/experiment/gmm_idem.yaml index 88fc386..dcc1d60 100644 --- a/configs/experiment/gmm_idem.yaml +++ b/configs/experiment/gmm_idem.yaml @@ -5,10 +5,9 @@ # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters -defaults: +defaults: - override /energy: gmm - tags: ["GMM", "iDEM"] seed: 12345 @@ -60,4 +59,3 @@ model: trainer: check_val_every_n_epoch: 5 - diff --git a/configs/experiment/gmm_pdem.yaml b/configs/experiment/gmm_pdem.yaml index 1725579..2a02769 100644 --- a/configs/experiment/gmm_pdem.yaml +++ b/configs/experiment/gmm_pdem.yaml @@ -5,10 +5,9 @@ # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters -defaults: +defaults: - override /energy: gmm - tags: ["GMM", "pDEM"] seed: 12345 diff --git a/configs/experiment/lj13_idem.yaml b/configs/experiment/lj13_idem.yaml index 5895ea1..c87479f 100644 --- a/configs/experiment/lj13_idem.yaml +++ b/configs/experiment/lj13_idem.yaml @@ -15,7 +15,7 @@ logger: tags: ${tags} group: "lj13" -defaults: +defaults: - override /energy: lj13 - override /model/net: egnn @@ -27,7 +27,7 @@ model: noise_schedule: _target_: dem.models.components.noise_schedules.GeometricNoiseSchedule sigma_min: 0.01 - sigma_max: 2 + sigma_max: 2 partial_prior: _target_: dem.energies.base_prior.MeanFreePrior @@ -50,7 +50,7 @@ model: diffusion_scale: 0.9 num_samples_to_generate_per_epoch: 1000 num_samples_to_sample_from_buffer: 512 - + init_from_prior: true cfm_prior_std: 2 diff --git 
a/configs/experiment/lj13_pdem.yaml b/configs/experiment/lj13_pdem.yaml index 06fe4cb..c130fff 100644 --- a/configs/experiment/lj13_pdem.yaml +++ b/configs/experiment/lj13_pdem.yaml @@ -15,7 +15,7 @@ logger: tags: ${tags} group: "lj13" -defaults: +defaults: - override /energy: lj13 - override /model/net: egnn @@ -27,7 +27,7 @@ model: noise_schedule: _target_: dem.models.components.noise_schedules.GeometricNoiseSchedule sigma_min: 0.01 - sigma_max: 2 + sigma_max: 2 partial_prior: _target_: dem.energies.base_prior.MeanFreePrior @@ -50,7 +50,7 @@ model: diffusion_scale: 0.9 num_samples_to_generate_per_epoch: 1000 num_samples_to_sample_from_buffer: 512 - + init_from_prior: true cfm_prior_std: 2 diff --git a/configs/experiment/lj55_idem.yaml b/configs/experiment/lj55_idem.yaml index 4ae2884..51f12f5 100644 --- a/configs/experiment/lj55_idem.yaml +++ b/configs/experiment/lj55_idem.yaml @@ -15,7 +15,7 @@ logger: tags: ${tags} group: "lj55" -defaults: +defaults: - override /energy: lj55 - override /model/net: egnn @@ -66,5 +66,3 @@ model: logz_with_cfm: true nll_integration_method: dopri5 - - diff --git a/configs/model/dem.yaml b/configs/model/dem.yaml index 366923a..13d14bc 100644 --- a/configs/model/dem.yaml +++ b/configs/model/dem.yaml @@ -15,9 +15,9 @@ scheduler: defaults: - net: - - mlp + - mlp - noise_schedule: - - geometric + - geometric buffer: _target_: dem.models.components.prioritised_replay_buffer.SimpleBuffer @@ -79,4 +79,3 @@ logz_with_cfm: false num_samples_to_save: 100000 tol: 1e-5 - diff --git a/configs/model/net/egnn.yaml b/configs/model/net/egnn.yaml index a74c7c8..a74652f 100644 --- a/configs/model/net/egnn.yaml +++ b/configs/model/net/egnn.yaml @@ -4,7 +4,7 @@ n_particles: 13 n_dimension: 3 hidden_nf: 32 n_layers: 3 -act_fn: +act_fn: _target_: torch.nn.SiLU recurrent: True tanh: True diff --git a/configs/model/net/pis_mlp.yaml b/configs/model/net/pis_mlp.yaml index c866f58..8cb8a72 100644 --- a/configs/model/net/pis_mlp.yaml +++ b/configs/model/net/pis_mlp.yaml @@ -3,4 +3,4 @@ _partial_: true num_layers: 2 channels: 64 in_shape: ${energy.dimensionality} -out_shape: ${energy.dimensionality} \ No newline at end of file +out_shape: ${energy.dimensionality} diff --git a/configs/model/pis.yaml b/configs/model/pis.yaml index 86f853f..cbd258c 100644 --- a/configs/model/pis.yaml +++ b/configs/model/pis.yaml @@ -30,7 +30,7 @@ tnet: defaults: - noise_schedule: - - geometric + - geometric buffer: _target_: dem.models.components.prioritised_replay_buffer.SimpleBuffer @@ -71,4 +71,4 @@ use_ema: false debug_use_train_data: false pis_scale: 1. -time_range: 1. \ No newline at end of file +time_range: 1. diff --git a/dem/data/dummy.py b/dem/data/dummy.py index 35c9888..9af5655 100644 --- a/dem/data/dummy.py +++ b/dem/data/dummy.py @@ -18,9 +18,7 @@ def __init__( self.batch_size = batch_size def get_dataloader(self, size): - return DataLoader( - np.arange(size * self.batch_size)[:, None], batch_size=self.batch_size - ) + return DataLoader(np.arange(size * self.batch_size)[:, None], batch_size=self.batch_size) def train_dataloader(self): return self.get_dataloader(self.n_train_batches_per_epoch) diff --git a/dem/data/mnist_datamodule.py b/dem/data/mnist_datamodule.py index c4e5ba5..88879ca 100644 --- a/dem/data/mnist_datamodule.py +++ b/dem/data/mnist_datamodule.py @@ -120,18 +120,12 @@ def setup(self, stage: Optional[str] = None) -> None: raise RuntimeError( f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." 
) - self.batch_size_per_device = ( - self.hparams.batch_size // self.trainer.world_size - ) + self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size # load and split datasets only if not loaded already if not self.data_train and not self.data_val and not self.data_test: - trainset = MNIST( - self.hparams.data_dir, train=True, transform=self.transforms - ) - testset = MNIST( - self.hparams.data_dir, train=False, transform=self.transforms - ) + trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms) + testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms) dataset = ConcatDataset(datasets=[trainset, testset]) self.data_train, self.data_val, self.data_test = random_split( dataset=dataset, diff --git a/dem/energies/base_energy_function.py b/dem/energies/base_energy_function.py index 5347e82..82982c9 100644 --- a/dem/energies/base_energy_function.py +++ b/dem/energies/base_energy_function.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Optional -import torch import numpy as np +import torch from pytorch_lightning.loggers import WandbLogger from dem.models.components.replay_buffer import ReplayBuffer @@ -62,9 +62,7 @@ def unnormalize(self, x: torch.Tensor) -> torch.Tensor: x = (x + 1) / 2 return x * (maxs - mins) + mins - def sample_test_set( - self, num_points: int, normalize: bool = False - ) -> Optional[torch.Tensor]: + def sample_test_set(self, num_points: int, normalize: bool = False) -> Optional[torch.Tensor]: if self.test_set is None: return None @@ -75,9 +73,7 @@ def sample_test_set( return outs - def sample_train_set( - self, num_points: int, normalize: bool = False - ) -> Optional[torch.Tensor]: + def sample_train_set(self, num_points: int, normalize: bool = False) -> Optional[torch.Tensor]: if self.train_set is None: self._train_set = self.setup_train_set() @@ -88,9 +84,7 @@ def sample_train_set( return outs - def sample_val_set( - self, num_points: int, normalize: bool = False - ) -> Optional[torch.Tensor]: + def sample_val_set(self, num_points: int, normalize: bool = False) -> Optional[torch.Tensor]: if self.val_set is None: return None diff --git a/dem/energies/base_prior.py b/dem/energies/base_prior.py index c62452a..9ebef79 100644 --- a/dem/energies/base_prior.py +++ b/dem/energies/base_prior.py @@ -1,6 +1,7 @@ import math -import torch from typing import Dict + +import torch from torch.distributions import constraints @@ -24,7 +25,7 @@ class MeanFreePrior(torch.distributions.Distribution): arg_constraints: Dict[str, constraints.Constraint] = {} def __init__(self, n_particles, spatial_dim, scale, device="cpu"): - super(MeanFreePrior, self).__init__() + super().__init__() self.n_particles = n_particles self.spatial_dim = spatial_dim self.dim = n_particles * spatial_dim diff --git a/dem/energies/gmm_energy.py b/dem/energies/gmm_energy.py index 2297e13..13b2d83 100644 --- a/dem/energies/gmm_energy.py +++ b/dem/energies/gmm_energy.py @@ -51,7 +51,7 @@ def __init__( self.test_set_size = test_set_size self.val_set_size = val_set_size - self.name="gmm" + self.name = "gmm" super().__init__( dimensionality=dimensionality, @@ -87,7 +87,7 @@ def log_on_epoch_end( latest_samples: torch.Tensor, latest_energies: torch.Tensor, wandb_logger: WandbLogger, - unprioritized_buffer_samples = None, + unprioritized_buffer_samples=None, cfm_samples=None, replay_buffer=None, prefix: str = "", @@ -99,11 +99,8 @@ def log_on_epoch_end( prefix += "/" if self.curr_epoch % self.plot_samples_epoch_period == 0: - if 
unprioritized_buffer_samples is not None: - buffer_samples, _, _ = replay_buffer.sample( - self.plotting_buffer_sample_size - ) + buffer_samples, _, _ = replay_buffer.sample(self.plotting_buffer_sample_size) if self.should_unnormalize: # Don't unnormalize CFM samples since they're in the @@ -111,32 +108,23 @@ def log_on_epoch_end( buffer_samples = self.unnormalize(buffer_samples) latest_samples = self.unnormalize(latest_samples) - unprioritized_buffer_samples = self.unnormalize( - unprioritized_buffer_samples - ) + unprioritized_buffer_samples = self.unnormalize(unprioritized_buffer_samples) samples_fig = self.get_dataset_fig(buffer_samples, latest_samples) wandb_logger.log_image(f"{prefix}generated_samples", [samples_fig]) if cfm_samples is not None: - cfm_samples_fig = self.get_dataset_fig( - unprioritized_buffer_samples, cfm_samples - ) + cfm_samples_fig = self.get_dataset_fig(unprioritized_buffer_samples, cfm_samples) - wandb_logger.log_image( - f"{prefix}cfm_generated_samples", [cfm_samples_fig] - ) + wandb_logger.log_image(f"{prefix}cfm_generated_samples", [cfm_samples_fig]) if latest_samples is not None: fig, ax = plt.subplots() ax.scatter(*latest_samples.detach().cpu().T) - - wandb_logger.log_image( - f"{prefix}generated_samples_scatter", [fig_to_image(fig)] - ) - self.get_single_dataset_fig(latest_samples, "dem_generated_samples") + wandb_logger.log_image(f"{prefix}generated_samples_scatter", [fig_to_image(fig)]) + self.get_single_dataset_fig(latest_samples, "dem_generated_samples") self.curr_epoch += 1 @@ -155,9 +143,7 @@ def log_samples( samples_fig = self.get_single_dataset_fig(samples, name) wandb_logger.log_image(f"{name}", [samples_fig]) - def get_single_dataset_fig( - self, samples, name, plotting_bounds=(-1.4 * 40, 1.4 * 40) - ): + def get_single_dataset_fig(self, samples, name, plotting_bounds=(-1.4 * 40, 1.4 * 40)): fig, ax = plt.subplots(1, 1, figsize=(8, 8)) self.gmm.to("cpu") @@ -176,9 +162,7 @@ def get_single_dataset_fig( return fig_to_image(fig) - def get_dataset_fig( - self, samples, gen_samples=None, plotting_bounds=(-1.4 * 40, 1.4 * 40) - ): + def get_dataset_fig(self, samples, gen_samples=None, plotting_bounds=(-1.4 * 40, 1.4 * 40)): fig, axs = plt.subplots(1, 2, figsize=(12, 4)) self.gmm.to("cpu") diff --git a/dem/energies/lennardjones_energy.py b/dem/energies/lennardjones_energy.py index 290112d..e8a83a1 100644 --- a/dem/energies/lennardjones_energy.py +++ b/dem/energies/lennardjones_energy.py @@ -52,7 +52,7 @@ def __init__( two_event_dims=True, energy_factor=1.0, ): - """Energy for a Lennard-Jones cluster + """Energy for a Lennard-Jones cluster. 
Parameters ---------- @@ -98,14 +98,10 @@ def _energy(self, x): lj_energies = lennard_jones_energy_torch(dists, self._eps, self._rm) # lj_energies = torch.clip(lj_energies, -1e4, 1e4) - lj_energies = ( - lj_energies.view(*batch_shape, -1).sum(dim=-1) * self._energy_factor - ) + lj_energies = lj_energies.view(*batch_shape, -1).sum(dim=-1) * self._energy_factor if self.oscillator: - osc_energies = 0.5 * self._remove_mean(x).pow(2).sum(dim=(-2, -1)).view( - *batch_shape - ) + osc_energies = 0.5 * self._remove_mean(x).pow(2).sum(dim=(-2, -1)).view(*batch_shape) lj_energies = lj_energies + osc_energies * self._oscillator_scale return lj_energies[:, None] @@ -211,8 +207,7 @@ def interatomic_dist(self, x): distances = x[:, None, :, :] - x[:, :, None, :] distances = distances[ :, - torch.triu(torch.ones((self.n_particles, self.n_particles)), diagonal=1) - == 1, + torch.triu(torch.ones((self.n_particles, self.n_particles)), diagonal=1) == 1, ] dist = torch.linalg.norm(distances, dim=-1) return dist @@ -222,7 +217,7 @@ def log_on_epoch_end( latest_samples: torch.Tensor, latest_energies: torch.Tensor, wandb_logger: WandbLogger, - unprioritized_buffer_samples = None, + unprioritized_buffer_samples=None, cfm_samples=None, replay_buffer=None, prefix: str = "", @@ -244,9 +239,7 @@ def log_on_epoch_end( if cfm_samples is not None: cfm_samples_fig = self.get_dataset_fig(cfm_samples) - wandb_logger.log_image( - f"{prefix}cfm_generated_samples", [cfm_samples_fig] - ) + wandb_logger.log_image(f"{prefix}cfm_generated_samples", [cfm_samples_fig]) self.curr_epoch += 1 @@ -328,6 +321,4 @@ def get_dataset_fig(self, samples): axs[1].legend() fig.canvas.draw() - return PIL.Image.frombytes( - "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() - ) + return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) diff --git a/dem/energies/lennardjones_energy_eacf.py b/dem/energies/lennardjones_energy_eacf.py index e7fe1db..778da14 100644 --- a/dem/energies/lennardjones_energy_eacf.py +++ b/dem/energies/lennardjones_energy_eacf.py @@ -56,18 +56,14 @@ def _get_senders_and_receivers_fully_connected(self, n_nodes): def _energy(self, x: torch.Tensor): if isinstance(self.rm, float): r = torch.ones(self._n_particles, device=self.device) * self.rm - senders, receivers = self._get_senders_and_receivers_fully_connected( - self._n_particles - ) + senders, receivers = self._get_senders_and_receivers_fully_connected(self._n_particles) vectors = x[senders] - x[receivers] d = torch.linalg.norm(vectors, ord=2, dim=-1) term_inside_sum = (r[receivers] / d) ** 12 - 2 * (r[receivers] / d) ** 6 energy = self.eps / (2 * self.tau) * term_inside_sum.sum() centre_of_mass = x.mean(dim=0) - harmonic_potential = ( - self.harmonic_potential_coef * (x - centre_of_mass).pow(2).sum() - ) + harmonic_potential = self.harmonic_potential_coef * (x - centre_of_mass).pow(2).sum() return energy + harmonic_potential def _log_prob(self, x: torch.Tensor): @@ -160,8 +156,7 @@ def interatomic_dist(self, x): distances = x[:, None, :, :] - x[:, :, None, :] distances = distances[ :, - torch.triu(torch.ones((self.n_particles, self.n_particles)), diagonal=1) - == 1, + torch.triu(torch.ones((self.n_particles, self.n_particles)), diagonal=1) == 1, ] dist = torch.linalg.norm(distances, dim=-1) return dist @@ -193,9 +188,7 @@ def log_on_epoch_end( if unprioritized_buffer_samples is not None: cfm_samples_fig = self.get_dataset_fig(cfm_samples) - wandb_logger.log_image( - f"{prefix}cfm_generated_samples", [cfm_samples_fig] - ) + 
wandb_logger.log_image(f"{prefix}cfm_generated_samples", [cfm_samples_fig]) self.curr_epoch += 1 @@ -272,6 +265,4 @@ def get_dataset_fig(self, samples): # plt.show() fig.canvas.draw() - return PIL.Image.frombytes( - "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() - ) + return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) diff --git a/dem/energies/multi_double_well_energy.py b/dem/energies/multi_double_well_energy.py index cf0d903..d6bb7da 100644 --- a/dem/energies/multi_double_well_energy.py +++ b/dem/energies/multi_double_well_energy.py @@ -115,9 +115,7 @@ def setup_val_set(self): else: all_data = np.load(self.data_path, allow_pickle=True) - data = all_data[0][ - -self.test_set_size - self.val_set_size : -self.test_set_size - ] + data = all_data[0][-self.test_set_size - self.val_set_size : -self.test_set_size] del all_data data = remove_mean(torch.tensor(data), self.n_particles, self.n_spatial_dim).to( @@ -134,8 +132,7 @@ def interatomic_dist(self, x): distances = x[:, None, :, :] - x[:, :, None, :] distances = distances[ :, - torch.triu(torch.ones((self.n_particles, self.n_particles)), diagonal=1) - == 1, + torch.triu(torch.ones((self.n_particles, self.n_particles)), diagonal=1) == 1, ] dist = torch.linalg.norm(distances, dim=-1) return dist @@ -158,7 +155,7 @@ def log_on_epoch_end( latest_samples: torch.Tensor, latest_energies: torch.Tensor, wandb_logger: WandbLogger, - unprioritized_buffer_samples = None, + unprioritized_buffer_samples=None, cfm_samples=None, replay_buffer=None, prefix: str = "", @@ -180,9 +177,7 @@ def log_on_epoch_end( if unprioritized_buffer_samples is not None: cfm_samples_fig = self.get_dataset_fig(cfm_samples) - wandb_logger.log_image( - f"{prefix}cfm_generated_samples", [cfm_samples_fig] - ) + wandb_logger.log_image(f"{prefix}cfm_generated_samples", [cfm_samples_fig]) self.curr_epoch += 1 @@ -245,6 +240,4 @@ def get_dataset_fig(self, samples): axs[1].legend() fig.canvas.draw() - return PIL.Image.frombytes( - "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() - ) + return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) diff --git a/dem/models/cfm_module.py b/dem/models/cfm_module.py index 1adacb8..c33b271 100644 --- a/dem/models/cfm_module.py +++ b/dem/models/cfm_module.py @@ -53,10 +53,7 @@ def t_stratified_loss(batch_t, batch_loss, num_bins=5, loss_name=None): def get_wandb_logger(loggers): - """ - Gets the wandb logger if it is the - list of loggers otherwise returns None. 
- """ + """Gets the wandb logger if it is the list of loggers otherwise returns None.""" wandb_logger = None for logger in loggers: if isinstance(logger, WandbLogger): @@ -182,9 +179,7 @@ def get_cfm_loss(self, samples: torch.Tensor) -> torch.Tensor: x1 = samples x1 = self.energy_function.unnormalize(x1) - t, xt, ut = self.conditional_flow_matcher.sample_location_and_conditional_flow( - x0, x1 - ) + t, xt, ut = self.conditional_flow_matcher.sample_location_and_conditional_flow(x0, x1) if self.energy_function.is_molecule and self.cfm_sigma != 0: xt = remove_mean( @@ -235,9 +230,7 @@ def on_train_epoch_end(self) -> None: ) self.last_energies = self.energy_function(self.last_samples) else: - self.last_samples = self.generate_samples( - diffusion_scale=self.diffusion_scale - ) + self.last_samples = self.generate_samples(diffusion_scale=self.diffusion_scale) self.last_energies = self.energy_function(self.last_samples) self.buffer.add(self.last_samples, self.last_energies) @@ -253,9 +246,7 @@ def _log_energy_w2(self): _, generated_energies = self.buffer.get_last_n_inserted(self.eval_batch_size) - energy_w2 = pot.emd2_1d( - val_energies.cpu().numpy(), generated_energies.cpu().numpy() - ) + energy_w2 = pot.emd2_1d(val_energies.cpu().numpy(), generated_energies.cpu().numpy()) self.log( f"val/energy_w2", @@ -266,9 +257,7 @@ def _log_energy_w2(self): ) def compute_log_z(self, cnf, prior, samples, prefix, name): - nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll( - cnf, prior, samples - ) + nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll(cnf, prior, samples) # this is stupid we should fix the normalization in the energy function logz = self.energy_function(self.energy_function.normalize(samples)) + nll logz_metric = getattr(self, f"{prefix}_{name}logz") @@ -283,9 +272,7 @@ def compute_log_z(self, cnf, prior, samples, prefix, name): def compute_and_log_nll(self, cnf, prior, samples, prefix, name): cnf.nfe = 0.0 - nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll( - cnf, prior, samples - ) + nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll(cnf, prior, samples) nfe_metric = getattr(self, f"{prefix}_{name}nfe") nll_metric = getattr(self, f"{prefix}_{name}nll") logdetjac_metric = getattr(self, f"{prefix}_{name}nll_logdetjac") @@ -336,24 +323,16 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None: # update and log metrics loss_metric = self.val_loss if prefix == "val" else self.test_loss loss_metric(loss) - self.log( - f"{prefix}/loss", loss_metric, on_step=True, on_epoch=True, prog_bar=True - ) + self.log(f"{prefix}/loss", loss_metric, on_step=True, on_epoch=True, prog_bar=True) forwards_samples = self.compute_and_log_nll( self.cfm_cnf, self.cfm_prior, batch, prefix, "" ) backwards_samples = self.cfm_cnf.generate(x1)[-1] - to_log = { - "data_0": batch, - "gen_0": backwards_samples, - "gen_1_cfm": forwards_samples - } + to_log = {"data_0": batch, "gen_0": backwards_samples, "gen_1_cfm": forwards_samples} iter_samples, _, _ = self.buffer.sample(self.eval_batch_size) # backwards_samples = self.generate_cfm_samples(self.eval_batch_size) - self.compute_log_z( - self.cfm_cnf, self.cfm_prior, backwards_samples, prefix, "" - ) + self.compute_log_z(self.cfm_cnf, self.cfm_prior, backwards_samples, prefix, "") self.eval_step_outputs.append(to_log) @@ -361,7 +340,6 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: batch = self.energy_function.sample_val_set(self.eval_batch_size) self.eval_step("val", batch, batch_idx) 
- def test_step(self, batch: torch.Tensor, batch_idx: int) -> None: batch = self.energy_function.sample_test_set(self.eval_batch_size) self.eval_step("test", batch, batch_idx) @@ -406,9 +384,7 @@ def setup(self, stage: str) -> None: self.cfm_net = torch.compile(self.cfm_net) if self.nll_with_cfm: - self.cfm_prior = self.partial_prior( - device=self.device, scale=self.cfm_prior_std - ) + self.cfm_prior = self.partial_prior(device=self.device, scale=self.cfm_prior_std) def configure_optimizers(self) -> Dict[str, Any]: """Choose what optimizers and learning-rate schedulers to use in your optimization. diff --git a/dem/models/components/clipper.py b/dem/models/components/clipper.py index 5e20234..02f09f9 100644 --- a/dem/models/components/clipper.py +++ b/dem/models/components/clipper.py @@ -1,6 +1,7 @@ -import torch from typing import Optional +import torch + _EPSILON = 1e-6 @@ -28,15 +29,13 @@ def should_clip_log_rewards(self) -> bool: def clip_scores(self, scores: torch.Tensor) -> torch.Tensor: score_norms = torch.linalg.vector_norm(scores, dim=-1).detach() - clip_coefficient = torch.clamp( - self.max_score_norm / (score_norms + _EPSILON), max=1 - ) + clip_coefficient = torch.clamp(self.max_score_norm / (score_norms + _EPSILON), max=1) return scores * clip_coefficient.unsqueeze(-1) def clip_log_rewards(self, log_rewards: torch.Tensor) -> torch.Tensor: return log_rewards.clamp(min=self.min_log_reward) - + def wrap_grad_fxn(self, grad_fxn): def _run(*args, **kwargs): scores = grad_fxn(*args, **kwargs) diff --git a/dem/models/components/cnf.py b/dem/models/components/cnf.py index f860d74..d8889c7 100644 --- a/dem/models/components/cnf.py +++ b/dem/models/components/cnf.py @@ -110,7 +110,7 @@ def vecfield(x): div = hutch_trace(x, dx, torch.randn_like(x)) self.nfe += 1 - #print(div.mean()) + # print(div.mean()) return torch.cat([dx.detach(), div[:, None].detach()], dim=-1) def integrate(self, x): @@ -120,9 +120,7 @@ def integrate(self, x): time = torch.linspace(start_time, end_time, self.num_steps + 1, device=x.device) try: - return odeint( - self, x, t=time, method=method, atol=self.atol, rtol=self.rtol - ) + return odeint(self, x, t=time, method=method, atol=self.atol, rtol=self.rtol) except (RuntimeError, AssertionError) as e: print(e) diff --git a/dem/models/components/distribution_distances.py b/dem/models/components/distribution_distances.py index dc3e9a2..2cb3c73 100644 --- a/dem/models/components/distribution_distances.py +++ b/dem/models/components/distribution_distances.py @@ -9,14 +9,16 @@ def compute_distances(pred, true): - """computes distances between vectors.""" + """Computes distances between vectors.""" mse = torch.nn.functional.mse_loss(pred, true).item() me = math.sqrt(mse) mae = torch.mean(torch.abs(pred - true)).item() return mse, me, mae -def compute_distribution_distances(pred: torch.Tensor, true: Union[torch.Tensor, list], energy_function): +def compute_distribution_distances( + pred: torch.Tensor, true: Union[torch.Tensor, list], energy_function +): """computes distances between distributions. 
pred: [batch, times, dims] tensor true: [batch, times, dims] tensor or list[batch[i], dims] of length times @@ -35,16 +37,14 @@ def compute_distribution_distances(pred: torch.Tensor, true: Union[torch.Tensor, "Median_MSE", "Median_L2", "Median_L1", - "Eq-EMD2" + "Eq-EMD2", ] is_jagged = isinstance(true, list) pred_is_jagged = isinstance(pred, list) dists = [] to_return = [] names = [] - filtered_names = [ - name for name in NAMES if not is_jagged or not name.endswith("MMD") - ] + filtered_names = [name for name in NAMES if not is_jagged or not name.endswith("MMD")] ts = len(pred) if pred_is_jagged else pred.shape[1] for t in np.arange(ts): if pred_is_jagged: @@ -59,17 +59,17 @@ def compute_distribution_distances(pred: torch.Tensor, true: Union[torch.Tensor, w2 = wasserstein(a, b, power=2) if energy_function.is_molecule: - eq_emd2 = eot(a.reshape(-1, energy_function.n_particles, energy_function.n_spatial_dim).cpu(), - b.reshape(-1, energy_function.n_particles, energy_function.n_spatial_dim).cpu()) + eq_emd2 = eot( + a.reshape(-1, energy_function.n_particles, energy_function.n_spatial_dim).cpu(), + b.reshape(-1, energy_function.n_particles, energy_function.n_spatial_dim).cpu(), + ) if not pred_is_jagged and not is_jagged: mmd_linear = linear_mmd2(a, b).item() mmd_poly = poly_mmd2(a, b, d=2, alpha=1.0, c=2.0).item() mmd_rbf = mix_rbf_mmd2(a, b, sigma_list=[0.01, 0.1, 1, 10, 100]).item() mean_dists = compute_distances(torch.mean(a, dim=0), torch.mean(b, dim=0)) - median_dists = compute_distances( - torch.median(a, dim=0)[0], torch.median(b, dim=0)[0] - ) + median_dists = compute_distances(torch.median(a, dim=0)[0], torch.median(b, dim=0)[0]) if pred_is_jagged or is_jagged: dists.append((w1, w2, *mean_dists, *median_dists)) else: @@ -78,9 +78,7 @@ def compute_distribution_distances(pred: torch.Tensor, true: Union[torch.Tensor, (w1, w2, mmd_linear, mmd_poly, mmd_rbf, *mean_dists, *median_dists, eq_emd2) ) else: - dists.append( - (w1, w2, mmd_linear, mmd_poly, mmd_rbf, *mean_dists, *median_dists) - ) + dists.append((w1, w2, mmd_linear, mmd_poly, mmd_rbf, *mean_dists, *median_dists)) # For multipoint datasets add timepoint specific distances if ts > 1: names.extend([f"t{t+1}/{name}" for name in filtered_names]) @@ -91,9 +89,9 @@ def compute_distribution_distances(pred: torch.Tensor, true: Union[torch.Tensor, return names, to_return - -from scipy.optimize import linear_sum_assignment import ot as pot +from scipy.optimize import linear_sum_assignment + def find_rigid_alignment(A, B): """ @@ -139,12 +137,14 @@ def find_rigid_alignment(A, B): t = t.T return R, t.squeeze() + def ot(x0, x1): dists = torch.cdist(x0, x1) _, col_ind = linear_sum_assignment(dists) x1 = x1[col_ind] return x1 + def eot(x0, x1): M = [] for i in range(len(x0)): diff --git a/dem/models/components/egnn.py b/dem/models/components/egnn.py index 087c921..7496b31 100644 --- a/dem/models/components/egnn.py +++ b/dem/models/components/egnn.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn + from dem.utils.data_utils import remove_mean @@ -101,7 +102,7 @@ def __init__( coords_range=15, agg="sum", ): - super(EGNN, self).__init__() + super().__init__() if out_node_nf is None: out_node_nf = in_node_nf self.hidden_nf = hidden_nf @@ -154,6 +155,7 @@ def forward(self, h, x, edges, edge_attr=None, node_mask=None, edge_mask=None): class E_GCL(nn.Module): """Graph Neural Net with global state and fixed number of nodes per graph. + Args: hidden_dim: Number of hidden units. num_nodes: Maximum number of nodes (for self-attentive pooling). 
@@ -177,7 +179,7 @@ def __init__( coords_range=1, agg="sum", ): - super(E_GCL, self).__init__() + super().__init__() input_edge = input_nf * 2 self.recurrent = recurrent self.attention = attention @@ -248,9 +250,7 @@ def node_model(self, x, edge_index, edge_attr, node_attr): out = x + out return out, agg - def coord_model( - self, coord, edge_index, coord_diff, radial, edge_feat, node_mask, edge_mask - ): + def coord_model(self, coord, edge_index, coord_diff, radial, edge_feat, node_mask, edge_mask): # print("coord_model", coord_diff, radial, edge_feat) row, col = edge_index if self.tanh: @@ -267,9 +267,7 @@ def coord_model( if node_mask is not None: # raise Exception('This part must be debugged before use') agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0)) - M = unsorted_segment_sum( - node_mask[col], row, num_segments=coord.size(0) - ) + M = unsorted_segment_sum(node_mask[col], row, num_segments=coord.size(0)) agg = agg / (M - 1) else: agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0)) diff --git a/dem/models/components/ema.py b/dem/models/components/ema.py index 9c3c70f..f63ff60 100644 --- a/dem/models/components/ema.py +++ b/dem/models/components/ema.py @@ -1,15 +1,15 @@ -"""Exponential moving average wrapper for torch.nn.Module""" +"""Exponential moving average wrapper for torch.nn.Module.""" import torch class EMAWrapper(torch.nn.Module): """Implements exponential moving average model wrapper. - Wraps a model where `ema.update_ema()` can be manually called to update ema - weights which are separately saved. + Wraps a model where `ema.update_ema()` can be manually called to update ema weights which are + separately saved. - with `ema.eval()` activates the EMA weights of the model for eval mode and - backs up the current weights. + with `ema.eval()` activates the EMA weights of the model for eval mode and backs up the current + weights. with `ema.train()` restores current weights. """ @@ -49,10 +49,8 @@ def train(self, mode: bool = True) -> None: super().train(use_training_mode) def _get_decay(self, num_updates: int) -> float: - """decay warmup magic from meta.""" - return min( - self.decay, (1 + num_updates) / (self.warmup_denominator + num_updates) - ) + """Decay warmup magic from meta.""" + return min(self.decay, (1 + num_updates) / (self.warmup_denominator + num_updates)) def update_ema(self) -> None: """Update the shadow params with a new EMA update.""" @@ -74,9 +72,8 @@ def copy_to_model(self) -> None: param.data.copy_(shadow.data) def backup(self) -> None: - """Create a backup of the model current params by creating a new copy - or copying in-place. 
- """ + """Create a backup of the model current params by creating a new copy or copying in- + place.""" if len(self.backup_params) > 0: for p, b in zip(self.model.parameters(), self.backup_params): b.data.copy_(p.data) @@ -84,6 +81,6 @@ def backup(self) -> None: self.backup_params = [param.clone() for param in self.model.parameters()] def restore_to_model(self) -> None: - """Move the backup parameters into the current model's parameters in-place""" + """Move the backup parameters into the current model's parameters in-place.""" for param, backup in zip(self.model.parameters(), self.backup_params): param.data.copy_(backup.data) diff --git a/dem/models/components/emd.py b/dem/models/components/emd.py index 7ecd18e..5654f11 100644 --- a/dem/models/components/emd.py +++ b/dem/models/components/emd.py @@ -13,8 +13,8 @@ def earth_mover_distance( return_matrix=False, metric="sqeuclidean", ): - """ - Returns the earth mover's distance between two point clouds + """Returns the earth mover's distance between two point clouds. + Parameters ---------- cloud1 : 2-D array @@ -43,9 +43,7 @@ def earth_mover_distance( weights2 = weights2.astype("float64") q_weights = weights2 / weights2.sum() - pairwise_dist = np.ascontiguousarray( - pairwise_distances(p, Y=q, metric=metric, n_jobs=-1) - ) + pairwise_dist = np.ascontiguousarray(pairwise_distances(p, Y=q, metric=metric, n_jobs=-1)) result = pot.emd2( p_weights, q_weights, pairwise_dist, numItermax=1e7, return_matrix=return_matrix @@ -58,8 +56,9 @@ def earth_mover_distance( def interpolate_with_ot(p0, p1, tmap, interp_frac, size): - """ - Interpolate between p0 and p1 at fraction t_interpolate knowing a transport map from p0 to p1 + """Interpolate between p0 and p1 at fraction t_interpolate knowing a transport map from p0 to + p1. + Parameters ---------- p0 : 2-D array @@ -106,8 +105,9 @@ def interpolate_with_ot(p0, p1, tmap, interp_frac, size): def interpolate_per_point_with_ot(p0, p1, tmap, interp_frac): - """ - Interpolate between p0 and p1 at fraction t_interpolate knowing a transport map from p0 to p1 + """Interpolate between p0 and p1 at fraction t_interpolate knowing a transport map from p0 to + p1. 
+ Parameters ---------- p0 : 2-D array @@ -149,9 +149,6 @@ def interpolate_per_point_with_ot(p0, p1, tmap, interp_frac): p = p / p.sum(axis=0) choices = np.array([np.random.choice(I, p=p[i]) for i in range(I)]) return np.asarray( - [ - p0[i] * (1 - interp_frac) + p1[j] * interp_frac - for i, j in enumerate(choices) - ], + [p0[i] * (1 - interp_frac) + p1[j] * interp_frac for i, j in enumerate(choices)], dtype=np.float64, ) diff --git a/dem/models/components/lambda_weighter.py b/dem/models/components/lambda_weighter.py index 1175c19..1862590 100644 --- a/dem/models/components/lambda_weighter.py +++ b/dem/models/components/lambda_weighter.py @@ -1,6 +1,7 @@ -import torch from abc import ABC, abstractmethod +import torch + from .noise_schedules import BaseNoiseSchedule diff --git a/dem/models/components/mlp.py b/dem/models/components/mlp.py index 2fb2a6f..46e17eb 100644 --- a/dem/models/components/mlp.py +++ b/dem/models/components/mlp.py @@ -1,10 +1,11 @@ -import torch import copy + +import numpy as np +import torch +from einops import rearrange from torch import nn from torch.nn import functional as F from torch.nn.utils import spectral_norm -from einops import rearrange -import numpy as np class SinusoidalEmbedding(nn.Module): @@ -97,9 +98,7 @@ def forward(self, x: torch.Tensor): class Block(nn.Module): - def __init__( - self, size: int, t_emb_size: int = 0, add_t_emb=False, concat_t_emb=False - ): + def __init__(self, size: int, t_emb_size: int = 0, add_t_emb=False, concat_t_emb=False): super().__init__() in_size = size + t_emb_size if concat_t_emb else size @@ -146,10 +145,7 @@ def __init__( ) self.layers = nn.Sequential( nn.GELU(), - *[ - nn.Sequential(nn.Linear(channels, channels), nn.GELU()) - for _ in range(num_layers) - ], + *[nn.Sequential(nn.Linear(channels, channels), nn.GELU()) for _ in range(num_layers)], nn.Linear(channels, int(np.prod(self.out_shape))), ) if zero_init: @@ -158,12 +154,8 @@ def __init__( def forward(self, cond, inputs): cond = cond.view(-1, 1).expand((inputs.shape[0], 1)) - sin_embed_cond = torch.sin( - (self.timestep_coeff * cond.float()) + self.timestep_phase - ) - cos_embed_cond = torch.cos( - (self.timestep_coeff * cond.float()) + self.timestep_phase - ) + sin_embed_cond = torch.sin((self.timestep_coeff * cond.float()) + self.timestep_phase) + cos_embed_cond = torch.cos((self.timestep_coeff * cond.float()) + self.timestep_phase) embed_cond = self.timestep_embed( rearrange([sin_embed_cond, cos_embed_cond], "d b w -> b (d w)") ) @@ -440,7 +432,7 @@ def __init__( hidden2_size: int = 128, output_size: int = 1, ): - super(SpectralNormMLP, self).__init__() + super().__init__() # First hidden layer with spectral normalization self.fc1 = spectral_norm(nn.Linear(input_size, hidden1_size)) diff --git a/dem/models/components/mmd.py b/dem/models/components/mmd.py index 62de59a..9ec3976 100644 --- a/dem/models/components/mmd.py +++ b/dem/models/components/mmd.py @@ -104,11 +104,7 @@ def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): - 2.0 * K_XY_sum / (m * m) ) else: - mmd2 = ( - Kt_XX_sum / (m * (m - 1)) - + Kt_YY_sum / (m * (m - 1)) - - 2.0 * K_XY_sum / (m * m) - ) + mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m) return mmd2 @@ -158,11 +154,7 @@ def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): - 2.0 * K_XY_sum / (m * m) ) else: - mmd2 = ( - Kt_XX_sum / (m * (m - 1)) - + Kt_YY_sum / (m * (m - 1)) - - 2.0 * K_XY_sum / (m * m) - ) + mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 
1)) - 2.0 * K_XY_sum / (m * m) var_est = ( 2.0 @@ -173,9 +165,7 @@ def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): + 2 * Kt_YY_sums.dot(Kt_YY_sums) - Kt_YY_2_sum ) - - (4.0 * m - 6.0) - / (m**3 * (m - 1.0) ** 3) - * (Kt_XX_sum**2 + Kt_YY_sum**2) + - (4.0 * m - 6.0) / (m**3 * (m - 1.0) ** 3) * (Kt_XX_sum**2 + Kt_YY_sum**2) + 4.0 * (m - 2.0) / (m**3 * (m - 1.0) ** 2) diff --git a/dem/models/components/noise_schedules.py b/dem/models/components/noise_schedules.py index e115969..8498416 100644 --- a/dem/models/components/noise_schedules.py +++ b/dem/models/components/noise_schedules.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -import torch import numpy as np import torch @@ -48,7 +47,7 @@ def g(self, t): return torch.sqrt(self.beta * self.power * (t ** (self.power - 1))) def h(self, t): - return self.beta * (t ** self.power) + return self.beta * (t**self.power) class SubLinearNoiseSchedule(BaseNoiseSchedule): @@ -72,11 +71,7 @@ def g(self, t): # Let sigma_d = sigma_max / sigma_min # Then g(t) = sigma_min * sigma_d^t * sqrt{2 * log(sigma_d)} # See Eq 192 in https://arxiv.org/pdf/2206.00364.pdf - return ( - self.sigma_min - * (self.sigma_diff**t) - * ((2 * np.log(self.sigma_diff)) ** 0.5) - ) + return self.sigma_min * (self.sigma_diff**t) * ((2 * np.log(self.sigma_diff)) ** 0.5) def h(self, t): # Let sigma_d = sigma_max / sigma_min diff --git a/dem/models/components/pis_net.py b/dem/models/components/pis_net.py index d33238f..a3c1bda 100644 --- a/dem/models/components/pis_net.py +++ b/dem/models/components/pis_net.py @@ -33,18 +33,14 @@ def _fn(t, x): def _fn(t, x): grad_fxn = torch.vmap(torch.func.grad(self.energy_function.__call__)) - grad = torch.clip( - grad_fxn(x), -self.lgv_clip, self.lgv_clip - ) + grad = torch.clip(grad_fxn(x), -self.lgv_clip, self.lgv_clip) f = torch.clip(self.f_func(t, x), -self.nn_clip, self.nn_clip) return f - self.lgv_coef(t) * grad elif f_format == "nn_grad": def _fn(t, x): - x_dot = torch.clip( - self.energy_function.score(x), -self.lgv_clip, self.lgv_clip - ) + x_dot = torch.clip(self.energy_function.score(x), -self.lgv_clip, self.lgv_clip) f_x = torch.clip(self.f_func(t, x), -self.nn_clip, self.nn_clip) return f_x * x_dot @@ -52,13 +48,9 @@ def _fn(t, x): self.grad_net = copy.deepcopy(self.f_func) def _fn(t, x): - x_dot = torch.clip( - self.energy_function(x), -self.lgv_clip, self.lgv_clip - ) + x_dot = torch.clip(self.energy_function(x), -self.lgv_clip, self.lgv_clip) f_x = torch.clip(self.f_func(t, x), -self.nn_clip, self.nn_clip) - f_x_dot = torch.clip( - self.grad_net(t, x_dot), -self.nn_clip, self.nn_clip - ) + f_x_dot = torch.clip(self.grad_net(t, x_dot), -self.nn_clip, self.nn_clip) return f_x + f_x_dot else: diff --git a/dem/models/components/prioritised_replay_buffer.py b/dem/models/components/prioritised_replay_buffer.py index 3dea03f..15aaacc 100644 --- a/dem/models/components/prioritised_replay_buffer.py +++ b/dem/models/components/prioritised_replay_buffer.py @@ -1,4 +1,5 @@ -from typing import NamedTuple, Tuple, Iterable, Callable, Optional +from typing import Callable, Iterable, NamedTuple, Optional, Tuple + import torch @@ -91,17 +92,13 @@ def __init__( print("Buffer not initialised, expected that checkpoint will be loaded.") @torch.no_grad() - def add( - self, x: torch.Tensor, log_w: torch.Tensor, log_q_old: torch.Tensor - ) -> None: - """Add a new batch of generated data to the replay buffer""" + def add(self, x: torch.Tensor, log_w: torch.Tensor, log_q_old: torch.Tensor) -> None: + """Add a new batch of 
generated data to the replay buffer.""" batch_size = x.shape[0] x = x.to(self.device) log_w = log_w.to(self.device) log_q_old = log_q_old.to(self.device) - indices = (torch.arange(batch_size) + self.current_index).to( - self.device - ) % self.max_length + indices = (torch.arange(batch_size) + self.current_index).to(self.device) % self.max_length self.buffer.x[indices] = x self.buffer.log_w[indices] = log_w self.buffer.log_q_old[indices] = log_q_old @@ -115,9 +112,8 @@ def add( def sample( self, batch_size: int, prioritize: Optional[bool] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Return a batch of sampled data, if the batch size is specified then the batch will have a - leading axis of length batch_size, otherwise the default self.batch_size will be used. - """ + """Return a batch of sampled data, if the batch size is specified then the batch will have + a leading axis of length batch_size, otherwise the default self.batch_size will be used.""" if not self.can_sample: raise Exception("Buffer must be at minimum length before calling sample") @@ -134,9 +130,9 @@ def sample( indices = torch.randint(max_index, (batch_size,)).to(self.device) else: if prioritize: - indices = sample_without_replacement( - self.buffer.log_w[:max_index], batch_size - ).to(self.device) + indices = sample_without_replacement(self.buffer.log_w[:max_index], batch_size).to( + self.device + ) else: indices = torch.randperm(max_index)[:batch_size].to(self.device) x, log_w, log_q_old, indices = ( @@ -275,13 +271,11 @@ def __len__(self): @torch.no_grad() def add(self, x: torch.Tensor, energy: torch.Tensor) -> None: - """Add a new batch of generated data to the replay buffer""" + """Add a new batch of generated data to the replay buffer.""" batch_size = x.shape[0] x = x.to(self.device) energy = energy.to(self.device) - indices = (torch.arange(batch_size) + self.current_index).to( - self.device - ) % self.max_length + indices = (torch.arange(batch_size) + self.current_index).to(self.device) % self.max_length self.buffer.x[indices] = x self.buffer.energy[indices] = energy new_index = self.current_index + batch_size @@ -309,9 +303,8 @@ def get_last_n_inserted(self, num_to_get: int) -> Tuple[torch.Tensor, torch.Tens def sample( self, batch_size: int, prioritize: Optional[bool] = None ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Return a batch of sampled data, if the batch size is specified then the batch will have a - leading axis of length batch_size, otherwise the default self.batch_size will be used. 
- """ + """Return a batch of sampled data, if the batch size is specified then the batch will have + a leading axis of length batch_size, otherwise the default self.batch_size will be used.""" if not self.can_sample: raise Exception("Buffer must be at minimum length before calling sample") @@ -384,8 +377,7 @@ def sample_n_batches( log_w_batches = torch.chunk(log_w, n_batches) indices_batches = torch.chunk(indices, n_batches) dataset = [ - (x, log_w, indxs) - for x, log_w, indxs in zip(x_batches, log_w_batches, indices_batches) + (x, log_w, indxs) for x, log_w, indxs in zip(x_batches, log_w_batches, indices_batches) ] return dataset @@ -426,8 +418,6 @@ def load(self, path): buffer = PrioritisedReplayBuffer(dim, length, min_sample_length, initial_sampler) n_batches = 3 for i in range(100): - buffer.add( - torch.ones(batch_size, dim), torch.zeros(batch_size), torch.ones(batch_size) - ) + buffer.add(torch.ones(batch_size, dim), torch.zeros(batch_size), torch.ones(batch_size)) x, log_w, log_q_old, indices = buffer.sample(batch_size) buffer.adjust(log_w + 1, log_q_old + 0.1, indices) diff --git a/dem/models/components/replay_buffer.py b/dem/models/components/replay_buffer.py index c74da0c..4d7d564 100644 --- a/dem/models/components/replay_buffer.py +++ b/dem/models/components/replay_buffer.py @@ -1,4 +1,5 @@ -from typing import NamedTuple, Tuple, Iterable, Callable +from typing import Callable, Iterable, NamedTuple, Tuple + import torch @@ -68,13 +69,11 @@ def __init__( @torch.no_grad() def add(self, x: torch.Tensor, log_w: torch.Tensor): - """Add a batch of generated data to the replay buffer""" + """Add a batch of generated data to the replay buffer.""" batch_size = x.shape[0] x = x.to(self.device) log_w = log_w.to(self.device) - indices = (torch.arange(batch_size) + self.current_index).to( - self.device - ) % self.max_length + indices = (torch.arange(batch_size) + self.current_index).to(self.device) % self.max_length self.buffer.x[indices] = x self.buffer.log_w[indices] = log_w self.buffer.add_count[indices] = self.current_add_count @@ -102,17 +101,14 @@ def get_last_n_inserted(self, num_to_get: int) -> Tuple[torch.Tensor, torch.Tens @torch.no_grad() def sample(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]: - """Return a batch of sampled data, if the batch size is specified then the batch will have a - leading axis of length batch_size, otherwise the default self.batch_size will be used. 
- """ + """Return a batch of sampled data, if the batch size is specified then the batch will have + a leading axis of length batch_size, otherwise the default self.batch_size will be used.""" if not self.can_sample: raise Exception("Buffer must be at minimum length before calling sample") max_index = self.max_length if self.is_full else self.current_index rank = self.current_add_count - self.buffer.add_count[:max_index] probs = torch.pow(1 / rank, self.temperature) - indices = torch.multinomial( - probs, num_samples=batch_size, replacement=False - ).to( + indices = torch.multinomial(probs, num_samples=batch_size, replacement=False).to( self.device ) # sample uniformly return self.buffer.x[indices], self.buffer.log_w[indices] @@ -136,9 +132,7 @@ def sample_n_batches( length = n_batches_total_length * batch_size min_sample_length = int(length * 0.5) initial_sampler = lambda: (torch.ones(batch_size, dim), torch.zeros(batch_size)) - buffer = ReplayBuffer( - dim, length, min_sample_length, initial_sampler, temperature=0.0 - ) + buffer = ReplayBuffer(dim, length, min_sample_length, initial_sampler, temperature=0.0) n_batches = 3 for i in range(100): buffer.add(torch.ones(batch_size, dim), torch.zeros(batch_size)) diff --git a/dem/models/components/scaling_wrapper.py b/dem/models/components/scaling_wrapper.py index 04bded1..a9173c1 100644 --- a/dem/models/components/scaling_wrapper.py +++ b/dem/models/components/scaling_wrapper.py @@ -1,11 +1,10 @@ from typing import Optional + import torch class ScalingWrapper(torch.nn.Module): - """ - (Tries to) normalize data and blah blah blah - """ + """(Tries to) normalize data and blah blah blah.""" def __init__( self, diff --git a/dem/models/components/score_estimator.py b/dem/models/components/score_estimator.py index 2e1ccb9..dd47c7e 100644 --- a/dem/models/components/score_estimator.py +++ b/dem/models/components/score_estimator.py @@ -1,26 +1,17 @@ -import torch import numpy as np +import torch + from dem.energies.base_energy_function import BaseEnergyFunction -from dem.models.components.noise_schedules import BaseNoiseSchedule from dem.models.components.clipper import Clipper +from dem.models.components.noise_schedules import BaseNoiseSchedule def wrap_for_richardsons(score_estimator): def _fxn(t, x, energy_function, noise_schedule, num_mc_samples): - bigger_samples = score_estimator( - t, - x, - energy_function, - noise_schedule, - num_mc_samples - ) + bigger_samples = score_estimator(t, x, energy_function, noise_schedule, num_mc_samples) smaller_samples = score_estimator( - t, - x, - energy_function, - noise_schedule, - int(num_mc_samples / 2) + t, x, energy_function, noise_schedule, int(num_mc_samples / 2) ) return (2 * bigger_samples) - smaller_samples @@ -62,8 +53,6 @@ def estimate_grad_Rt( t = t.unsqueeze(0).repeat(len(x)) grad_fxn = torch.func.grad(log_expectation_reward, argnums=1) - vmapped_fxn = torch.vmap( - grad_fxn, in_dims=(0, 0, None, None, None), randomness="different" - ) + vmapped_fxn = torch.vmap(grad_fxn, in_dims=(0, 0, None, None, None), randomness="different") return vmapped_fxn(t, x, energy_function, noise_schedule, num_mc_samples) diff --git a/dem/models/components/score_scaler.py b/dem/models/components/score_scaler.py index b4f1b2c..3a04cae 100644 --- a/dem/models/components/score_scaler.py +++ b/dem/models/components/score_scaler.py @@ -1,14 +1,13 @@ -import torch from abc import ABC, abstractmethod +import torch + from .noise_schedules import BaseNoiseSchedule class BaseScoreScaler(ABC): @abstractmethod - def 
scale_target_score( - self, target_score: torch.Tensor, times: torch.Tensor - ) -> torch.Tensor: + def scale_target_score(self, target_score: torch.Tensor, times: torch.Tensor) -> torch.Tensor: pass @abstractmethod @@ -27,21 +26,15 @@ def __init__( self.constant_scaling_factor = constant_scaling_factor self.epsilon = epsilon - def _get_scale_factor( - self, score: torch.Tensor, times: torch.Tensor - ) -> torch.Tensor: + def _get_scale_factor(self, score: torch.Tensor, times: torch.Tensor) -> torch.Tensor: # call view to expand h_t to the number of dimensions of target_score - h_t = self.noise_schedule.h(times).view( - -1, *(1 for _ in range(score.ndim - times.ndim)) - ) + h_t = self.noise_schedule.h(times).view(-1, *(1 for _ in range(score.ndim - times.ndim))) h_t = h_t * self.constant_scaling_factor return (h_t * self.constant_scaling_factor) + self.epsilon - def scale_target_score( - self, target_score: torch.Tensor, times: torch.Tensor - ) -> torch.Tensor: + def scale_target_score(self, target_score: torch.Tensor, times: torch.Tensor) -> torch.Tensor: return target_score * self._get_scale_factor(target_score, times) def _build_wrapper_class(self): diff --git a/dem/models/components/sde_integration.py b/dem/models/components/sde_integration.py index 710b719..2b3b146 100644 --- a/dem/models/components/sde_integration.py +++ b/dem/models/components/sde_integration.py @@ -1,9 +1,9 @@ import numpy as np import torch +from dem.energies.base_energy_function import BaseEnergyFunction from dem.models.components.sdes import VEReverseSDE from dem.utils.data_utils import remove_mean -from dem.energies.base_energy_function import BaseEnergyFunction def euler_maruyama_step( @@ -27,9 +27,7 @@ def integrate_pfode( start_time = 1.0 if reverse_time else 0.0 end_time = 1.0 - start_time - times = torch.linspace( - start_time, end_time, num_integration_steps + 1, device=x0.device - )[:-1] + times = torch.linspace(start_time, end_time, num_integration_steps + 1, device=x0.device)[:-1] x = x0 samples = [] @@ -54,9 +52,7 @@ def integrate_sde( start_time = time_range if reverse_time else 0.0 end_time = time_range - start_time - times = torch.linspace( - start_time, end_time, num_integration_steps + 1, device=x0.device - )[:-1] + times = torch.linspace(start_time, end_time, num_integration_steps + 1, device=x0.device)[:-1] x = x0 samples = [] diff --git a/dem/models/components/sdes.py b/dem/models/components/sdes.py index 97e0490..83f183c 100644 --- a/dem/models/components/sdes.py +++ b/dem/models/components/sdes.py @@ -1,5 +1,6 @@ import torch + class SDE(torch.nn.Module): noise_type = "diagonal" sde_type = "ito" @@ -13,7 +14,7 @@ def f(self, t, x): if t.dim() == 0: # repeat the same time for all points if we have a scalar time t = t * torch.ones(x.shape[0]).to(x.device) - + return self.drift(t, x) def g(self, t, x): @@ -52,6 +53,4 @@ def g(self, t, x): g = self.noise_schedule.g(t) if g.ndim > 0: return g.unsqueeze(1) - return torch.cat( - [torch.full_like(x[..., :-1], g), torch.zeros_like(x[..., -1:])], dim=-1 - ) \ No newline at end of file + return torch.cat([torch.full_like(x[..., :-1], g), torch.zeros_like(x[..., -1:])], dim=-1) diff --git a/dem/models/dem_module.py b/dem/models/dem_module.py index 37feffc..8ab3f34 100644 --- a/dem/models/dem_module.py +++ b/dem/models/dem_module.py @@ -1,12 +1,12 @@ -from typing import Any, Dict, Optional import time +from typing import Any, Dict, Optional -from hydra.utils import get_original_cwd import hydra import matplotlib.pyplot as plt -import ot as pot import numpy 
as np +import ot as pot import torch +from hydra.utils import get_original_cwd from lightning import LightningModule from lightning.pytorch.loggers import WandbLogger from torchcfm.conditional_flow_matching import ( @@ -55,10 +55,7 @@ def t_stratified_loss(batch_t, batch_loss, num_bins=5, loss_name=None): def get_wandb_logger(loggers): - """ - Gets the wandb logger if it is the - list of loggers otherwise returns None. - """ + """Gets the wandb logger if it is the list of loggers otherwise returns None.""" wandb_logger = None for logger in loggers: if isinstance(logger, WandbLogger): @@ -168,9 +165,7 @@ def __init__( self.net = EMAWrapper(self.net) self.cfm_net = EMAWrapper(self.cfm_net) if input_scaling_factor is not None or output_scaling_factor is not None: - self.net = ScalingWrapper( - self.net, input_scaling_factor, output_scaling_factor - ) + self.net = ScalingWrapper(self.net, input_scaling_factor, output_scaling_factor) self.cfm_net = ScalingWrapper( self.cfm_net, input_scaling_factor, output_scaling_factor @@ -210,7 +205,6 @@ def __init__( self.cfm_prior_std = cfm_prior_std self.compute_nll_on_train_data = compute_nll_on_train_data - flow_matcher = ConditionalFlowMatcher if use_otcfm: flow_matcher = ExactOptimalTransportConditionalFlowMatcher @@ -321,9 +315,7 @@ def get_cfm_loss(self, samples: torch.Tensor) -> torch.Tensor: x1 = samples x1 = self.energy_function.unnormalize(x1) - t, xt, ut = self.conditional_flow_matcher.sample_location_and_conditional_flow( - x0, x1 - ) + t, xt, ut = self.conditional_flow_matcher.sample_location_and_conditional_flow(x0, x1) if self.energy_function.is_molecule and self.cfm_sigma != 0: xt = remove_mean( @@ -341,12 +333,14 @@ def get_cfm_loss(self, samples: torch.Tensor) -> torch.Tensor: def should_train_cfm(self, batch_idx: int) -> bool: return self.nll_with_cfm or self.hparams.debug_use_train_data - def get_score_loss(self, times: torch.Tensor, - samples: torch.Tensor, - noised_samples: torch.Tensor) -> torch.Tensor: + def get_score_loss( + self, times: torch.Tensor, samples: torch.Tensor, noised_samples: torch.Tensor + ) -> torch.Tensor: predicted_score = self.forward(times, noised_samples) - true_score = -(noised_samples - samples)/ (self.noise_schedule.h(times).unsqueeze(1) + 1e-4) + true_score = -(noised_samples - samples) / ( + self.noise_schedule.h(times).unsqueeze(1) + 1e-4 + ) error_norms = (predicted_score - true_score).pow(2).mean(-1) return error_norms @@ -370,14 +364,10 @@ def get_loss(self, times: torch.Tensor, samples: torch.Tensor) -> torch.Tensor: estimated_score = self.clipper.clip_scores(estimated_score) if self.energy_function.is_molecule: - estimated_score = estimated_score.reshape( - -1, self.energy_function.dimensionality - ) + estimated_score = estimated_score.reshape(-1, self.energy_function.dimensionality) if self.score_scaler is not None: - estimated_score = self.score_scaler.scale_target_score( - estimated_score, times - ) + estimated_score = self.score_scaler.scale_target_score(estimated_score, times) predicted_score = self.forward(times, samples) @@ -389,21 +379,18 @@ def training_step(self, batch, batch_idx): loss = 0.0 if not self.hparams.debug_use_train_data: if self.hparams.use_buffer: - iter_samples, _, _ = self.buffer.sample( - self.num_samples_to_sample_from_buffer - ) + iter_samples, _, _ = self.buffer.sample(self.num_samples_to_sample_from_buffer) else: iter_samples = self.prior.sample(self.num_samples_to_sample_from_buffer) # Uncomment for SM - #iter_samples = 
self.energy_function.sample_train_set(self.num_samples_to_sample_from_buffer) + # iter_samples = self.energy_function.sample_train_set(self.num_samples_to_sample_from_buffer) times = torch.rand( (self.num_samples_to_sample_from_buffer,), device=iter_samples.device ) noised_samples = iter_samples + ( - torch.randn_like(iter_samples) - * self.noise_schedule.h(times).sqrt().unsqueeze(-1) + torch.randn_like(iter_samples) * self.noise_schedule.h(times).sqrt().unsqueeze(-1) ) if self.energy_function.is_molecule: @@ -415,11 +402,9 @@ def training_step(self, batch, batch_idx): dem_loss = self.get_loss(times, noised_samples) # Uncomment for SM - #dem_loss = self.get_score_loss(times, iter_samples, noised_samples) + # dem_loss = self.get_score_loss(times, iter_samples, noised_samples) self.log_dict( - t_stratified_loss( - times, dem_loss, loss_name="train/stratified/dem_loss" - ) + t_stratified_loss(times, dem_loss, loss_name="train/stratified/dem_loss") ) dem_loss = dem_loss.mean() loss = loss + dem_loss @@ -450,9 +435,7 @@ def training_step(self, batch, batch_idx): cfm_loss = self.get_cfm_loss(cfm_samples) self.log_dict( - t_stratified_loss( - times, cfm_loss, loss_name="train/stratified/cfm_loss" - ) + t_stratified_loss(times, cfm_loss, loss_name="train/stratified/cfm_loss") ) cfm_loss = cfm_loss.mean() self.cfm_train_loss(cfm_loss) @@ -545,9 +528,7 @@ def on_train_epoch_end(self) -> None: ) self.last_energies = self.energy_function(self.last_samples) else: - self.last_samples = self.generate_samples( - diffusion_scale=self.diffusion_scale - ) + self.last_samples = self.generate_samples(diffusion_scale=self.diffusion_scale) self.last_energies = self.energy_function(self.last_samples) self.buffer.add(self.last_samples, self.last_energies) @@ -561,8 +542,9 @@ def on_train_epoch_end(self) -> None: def _log_energy_w2(self, prefix="val"): if prefix == "test": data_set = self.energy_function.sample_val_set(self.eval_batch_size) - generated_samples = self.generate_samples(num_samples=self.eval_batch_size, - diffusion_scale=self.diffusion_scale) + generated_samples = self.generate_samples( + num_samples=self.eval_batch_size, diffusion_scale=self.diffusion_scale + ) generated_energies = self.energy_function(generated_samples) else: if len(self.buffer) < self.eval_batch_size: @@ -571,10 +553,7 @@ def _log_energy_w2(self, prefix="val"): _, generated_energies = self.buffer.get_last_n_inserted(self.eval_batch_size) energies = self.energy_function(self.energy_function.normalize(data_set)) - energy_w2 = pot.emd2_1d( - energies.cpu().numpy(), - generated_energies.cpu().numpy() - ) + energy_w2 = pot.emd2_1d(energies.cpu().numpy(), generated_energies.cpu().numpy()) self.log( f"{prefix}/energy_w2", @@ -584,13 +563,15 @@ def _log_energy_w2(self, prefix="val"): prog_bar=True, ) - def _log_dist_w2(self, prefix="val"): if prefix == "test": - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() data_set = self.energy_function.sample_val_set(self.eval_batch_size) - generated_samples = self.generate_samples(num_samples=self.eval_batch_size, - diffusion_scale=self.diffusion_scale) + generated_samples = self.generate_samples( + num_samples=self.eval_batch_size, diffusion_scale=self.diffusion_scale + ) else: if len(self.buffer) < self.eval_batch_size: return @@ -599,7 +580,7 @@ def _log_dist_w2(self, prefix="val"): dist_w2 = pot.emd2_1d( self.energy_function.interatomic_dist(generated_samples).cpu().numpy().reshape(-1), - self.energy_function.interatomic_dist(data_set).cpu().numpy().reshape(-1) + 
self.energy_function.interatomic_dist(data_set).cpu().numpy().reshape(-1), ) self.log( f"{prefix}/dist_w2", @@ -611,22 +592,32 @@ def _log_dist_w2(self, prefix="val"): def _log_dist_total_var(self, prefix="val"): if prefix == "test": - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() data_set = self.energy_function.sample_val_set(self.eval_batch_size) - generated_samples = self.generate_samples(num_samples=self.eval_batch_size, - diffusion_scale=self.diffusion_scale) + generated_samples = self.generate_samples( + num_samples=self.eval_batch_size, diffusion_scale=self.diffusion_scale + ) else: if len(self.buffer) < self.eval_batch_size: return data_set = self.energy_function.sample_test_set(self.eval_batch_size) generated_samples, _ = self.buffer.get_last_n_inserted(self.eval_batch_size) - generated_samples_dists = self.energy_function.interatomic_dist(generated_samples).cpu().numpy().reshape(-1), + generated_samples_dists = ( + self.energy_function.interatomic_dist(generated_samples).cpu().numpy().reshape(-1), + ) data_set_dists = self.energy_function.interatomic_dist(data_set).cpu().numpy().reshape(-1) H_data_set, x_data_set = np.histogram(data_set_dists, bins=200) H_generated_samples, _ = np.histogram(generated_samples_dists, bins=(x_data_set)) - total_var = 0.5 * np.abs(H_data_set/H_data_set.sum() - H_generated_samples/H_generated_samples.sum()).sum() + total_var = ( + 0.5 + * np.abs( + H_data_set / H_data_set.sum() - H_generated_samples / H_generated_samples.sum() + ).sum() + ) self.log( f"{prefix}/dist_total_var", @@ -637,9 +628,7 @@ def _log_dist_total_var(self, prefix="val"): ) def compute_log_z(self, cnf, prior, samples, prefix, name): - nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll( - cnf, prior, samples - ) + nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll(cnf, prior, samples) # this is stupid we should fix the normalization in the energy function logz = self.energy_function(self.energy_function.normalize(samples)) + nll logz_metric = getattr(self, f"{prefix}_{name}logz") @@ -654,9 +643,7 @@ def compute_log_z(self, cnf, prior, samples, prefix, name): def compute_and_log_nll(self, cnf, prior, samples, prefix, name): cnf.nfe = 0.0 - nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll( - cnf, prior, samples - ) + nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll(cnf, prior, samples) nfe_metric = getattr(self, f"{prefix}_{name}nfe") nll_metric = getattr(self, f"{prefix}_{name}nll") logdetjac_metric = getattr(self, f"{prefix}_{name}nll_logdetjac") @@ -713,7 +700,7 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None: num_samples=self.eval_batch_size, diffusion_scale=self.diffusion_scale ) - # sample eval_batch_size from generated samples from dem to match dimenstions + # sample eval_batch_size from generated samples from dem to match dimensions # required for distribution metrics if len(backwards_samples) != self.eval_batch_size: indices = torch.randperm(len(backwards_samples))[: self.eval_batch_size] @@ -744,9 +731,7 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None: loss_metric = self.val_loss if prefix == "val" else self.test_loss loss_metric(loss) - self.log( - f"{prefix}/loss", loss_metric, on_step=True, on_epoch=True, prog_bar=True - ) + self.log(f"{prefix}/loss", loss_metric, on_step=True, on_epoch=True, prog_bar=True) to_log = { "data_0": batch, @@ -759,9 +744,7 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None: self.dem_cnf, 
self.prior, batch, prefix, "dem_" ) to_log["gen_1_dem"] = forwards_samples - self.compute_log_z( - self.cfm_cnf, self.prior, backwards_samples, prefix, "dem_" - ) + self.compute_log_z(self.cfm_cnf, self.prior, backwards_samples, prefix, "dem_") if self.nll_with_cfm: forwards_samples = self.compute_and_log_nll( self.cfm_cnf, self.cfm_prior, batch, prefix, "" @@ -777,9 +760,7 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None: ) if self.compute_nll_on_train_data: - train_samples = self.energy_function.sample_train_set( - self.eval_batch_size - ) + train_samples = self.energy_function.sample_train_set(self.eval_batch_size) forwards_samples = self.compute_and_log_nll( self.cfm_cnf, self.cfm_prior, train_samples, prefix, "train_" ) @@ -789,9 +770,7 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None: self.cfm_prior.sample(self.eval_batch_size), )[-1] # backwards_samples = self.generate_cfm_samples(self.eval_batch_size) - self.compute_log_z( - self.cfm_cnf, self.cfm_prior, backwards_samples, prefix, "" - ) + self.compute_log_z(self.cfm_cnf, self.cfm_prior, backwards_samples, prefix, "") self.eval_step_outputs.append(to_log) @@ -824,9 +803,9 @@ def eval_epoch_end(self, prefix: str): self.last_samples, self.last_energies, wandb_logger, - unprioritized_buffer_samples = unprioritized_buffer_samples, - cfm_samples = cfm_samples, - replay_buffer = self.buffer, + unprioritized_buffer_samples=unprioritized_buffer_samples, + cfm_samples=cfm_samples, + replay_buffer=self.buffer, ) else: @@ -841,7 +820,8 @@ def eval_epoch_end(self, prefix: str): # pad with time dimension 1 names, dists = compute_distribution_distances( self.energy_function.unnormalize(outputs["gen_0"])[:, None], - outputs["data_0"][:, None], self.energy_function + outputs["data_0"][:, None], + self.energy_function, ) names = [f"{prefix}/{name}" for name in names] d = dict(zip(names, dists)) @@ -868,14 +848,13 @@ def on_test_epoch_end(self) -> None: for i in range(n_batches): start = time.time() samples = self.generate_samples( - num_samples=batch_size, diffusion_scale=self.diffusion_scale - ) - final_samples.append(samples + num_samples=batch_size, diffusion_scale=self.diffusion_scale ) + final_samples.append(samples) end = time.time() print(f"batch {i} took {end - start:0.2f}s") - if i==0: + if i == 0: self.energy_function.log_on_epoch_end( samples, self.energy_function(samples), @@ -888,6 +867,7 @@ def on_test_epoch_end(self) -> None: torch.save(final_samples, path) print(f"Saving samples to {path}") import os + os.makedirs(self.energy_function.name, exist_ok=True) path2 = f"{self.energy_function.name}/samples_{self.hparams.version}_{self.num_samples_to_save}.pt" torch.save(final_samples, path2) @@ -914,9 +894,7 @@ def _grad_fxn(t, x): reverse_sde = VEReverseSDE(_grad_fxn, self.noise_schedule) - self.prior = self.partial_prior( - device=self.device, scale=self.noise_schedule.h(1) ** 0.5 - ) + self.prior = self.partial_prior(device=self.device, scale=self.noise_schedule.h(1) ** 0.5) if self.init_from_prior: init_states = self.prior.sample(self.num_init_samples) else: @@ -932,9 +910,7 @@ def _grad_fxn(t, x): self.cfm_net = torch.compile(self.cfm_net) if self.nll_with_cfm: - self.cfm_prior = self.partial_prior( - device=self.device, scale=self.cfm_prior_std - ) + self.cfm_prior = self.partial_prior(device=self.device, scale=self.cfm_prior_std) def configure_optimizers(self) -> Dict[str, Any]: """Choose what optimizers and learning-rate schedulers to use in your optimization. 
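Note on the `dem/models/components/score_estimator.py` and `dem/models/dem_module.py` hunks above: they are formatting-only, which makes the quantity being computed easy to lose in the diff. The regression target is a Monte Carlo estimate of the gradient of a log-expectation of the exponentiated energy under Gaussian noising, optionally debiased by `wrap_for_richardsons`. Below is a minimal single-point sketch of that idea with a toy energy; `toy_energy`, `estimate_score`, and `richardson_score` are illustrative names only, not the repository's API, which instead vmaps `estimate_grad_Rt` over a batch of points.

```python
import torch


def toy_energy(x: torch.Tensor) -> torch.Tensor:
    # Stand-in log-density (standard normal); returns one value per row.
    return -0.5 * (x**2).sum(-1)


def log_expectation_reward(t, x, energy, h_t, num_mc_samples=100):
    # log (1/K) sum_i exp E(x + sqrt(h(t)) * eps_i),  eps_i ~ N(0, I)
    eps = torch.randn(num_mc_samples, *x.shape)
    noised = x + h_t.sqrt() * eps
    return torch.logsumexp(energy(noised), dim=0) - torch.log(
        torch.tensor(float(num_mc_samples))
    )


def estimate_score(t, x, energy, h_t, num_mc_samples=100):
    # DEM regression target: gradient of the log-expectation w.r.t. x.
    grad_fn = torch.func.grad(log_expectation_reward, argnums=1)
    return grad_fn(t, x, energy, h_t, num_mc_samples)


def richardson_score(t, x, energy, h_t, num_mc_samples=100):
    # Combine a K-sample and a K/2-sample estimate to cancel the leading-order
    # bias, mirroring wrap_for_richardsons in the hunk above.
    big = estimate_score(t, x, energy, h_t, num_mc_samples)
    small = estimate_score(t, x, energy, h_t, num_mc_samples // 2)
    return 2 * big - small


x = torch.zeros(2)        # a single 2-d point
t = torch.tensor(0.5)     # time is unused by the toy energy
h_t = torch.tensor(0.5)   # noise-schedule variance at time t (illustrative value)
print(richardson_score(t, x, toy_energy, h_t))
```

In `training_step` above, the same estimate is formed at noised points `x + sqrt(h(t)) * eps` and regressed against the network output `self.forward(times, samples)`.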
diff --git a/dem/models/mnist_module.py b/dem/models/mnist_module.py index 7081aaa..5d303ac 100644 --- a/dem/models/mnist_module.py +++ b/dem/models/mnist_module.py @@ -125,12 +125,8 @@ def training_step( # update and log metrics self.train_loss(loss) self.train_acc(preds, targets) - self.log( - "train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True - ) - self.log( - "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True - ) + self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True) # return loss or backpropagation will fail return loss @@ -139,9 +135,7 @@ def on_train_epoch_end(self) -> None: "Lightning hook that is called when a training epoch ends." pass - def validation_step( - self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int - ) -> None: + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: """Perform a single validation step on a batch of data from the validation set. :param batch: A batch of data (a tuple) containing the input tensor of images and target @@ -162,13 +156,9 @@ def on_validation_epoch_end(self) -> None: self.val_acc_best(acc) # update best so far val acc # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object # otherwise metric would be reset by lightning after each epoch - self.log( - "val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True - ) + self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True) - def test_step( - self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int - ) -> None: + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: """Perform a single test step on a batch of data from the test set. 
:param batch: A batch of data (a tuple) containing the input tensor of images and target @@ -180,9 +170,7 @@ def test_step( # update and log metrics self.test_loss(loss) self.test_acc(preds, targets) - self.log( - "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True - ) + self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) def on_test_epoch_end(self) -> None: diff --git a/dem/models/pis_module.py b/dem/models/pis_module.py index 7ba9083..3acb3be 100644 --- a/dem/models/pis_module.py +++ b/dem/models/pis_module.py @@ -1,3 +1,4 @@ +import math from typing import Any, Dict, Optional, Tuple import matplotlib.pyplot as plt @@ -5,12 +6,11 @@ import torch from lightning import LightningModule from lightning.pytorch.loggers import WandbLogger -from torchdiffeq import odeint - from torchcfm.conditional_flow_matching import ( ConditionalFlowMatcher, ExactOptimalTransportConditionalFlowMatcher, ) +from torchdiffeq import odeint from torchdyn.core import NeuralODE from torchmetrics import MeanMetric @@ -21,22 +21,22 @@ from .components.cnf import CNF from .components.distribution_distances import compute_distribution_distances from .components.ema import EMAWrapper -from .components.noise_schedules import BaseNoiseSchedule from .components.lambda_weighter import BaseLambdaWeighter - +from .components.noise_schedules import BaseNoiseSchedule from .components.prioritised_replay_buffer import PrioritisedReplayBuffer from .components.scaling_wrapper import ScalingWrapper from .components.score_estimator import estimate_grad_Rt from .components.score_scaler import BaseScoreScaler from .components.sde_integration import integrate_sde from .components.sdes import SDE -import math logtwopi = math.log(2 * math.pi) + def logmeanexp(x, dim=0): return x.logsumexp(dim) - math.log(x.shape[dim]) + def t_stratified_loss(batch_t, batch_loss, num_bins=5, loss_name=None): """Stratify loss by binning t.""" flat_losses = batch_loss.flatten().detach().cpu().numpy() @@ -56,11 +56,9 @@ def t_stratified_loss(batch_t, batch_loss, num_bins=5, loss_name=None): stratified_losses[t_range] = range_loss return stratified_losses + def get_wandb_logger(loggers): - """ - Gets the wandb logger if it is the - list of loggers otherwise returns None. 
- """ + """Gets the wandb logger if it is the list of loggers otherwise returns None.""" wandb_logger = None for logger in loggers: if isinstance(logger, WandbLogger): @@ -139,7 +137,7 @@ def __init__( diffusion_scale=1.0, cfm_loss_weight=1.0, pis_scale=1.0, - time_range=5., + time_range=5.0, use_ema=False, debug_use_train_data=False, init_from_prior=False, @@ -189,9 +187,7 @@ def __init__( self.net = EMAWrapper(self.net) self.cfm_net = EMAWrapper(self.cfm_net) if input_scaling_factor is not None or output_scaling_factor is not None: - self.net = ScalingWrapper( - self.net, input_scaling_factor, output_scaling_factor - ) + self.net = ScalingWrapper(self.net, input_scaling_factor, output_scaling_factor) self.cfm_net = ScalingWrapper( self.cfm_net, input_scaling_factor, output_scaling_factor @@ -204,7 +200,9 @@ def __init__( self.net = self.score_scaler.wrap_model_for_unscaling(self.net) self.cfm_net = self.score_scaler.wrap_model_for_unscaling(self.cfm_net) - self.cfm_cnf = CNF(self.cfm_net, is_diffusion=False, use_exact_likelihood=use_exact_likelihood) + self.cfm_cnf = CNF( + self.cfm_net, is_diffusion=False, use_exact_likelihood=use_exact_likelihood + ) self.num_init_samples = num_init_samples self.num_estimator_mc_samples = num_estimator_mc_samples @@ -268,14 +266,14 @@ def score(self, x): self.energy_function(copy_x).sum().backward() lgv_data = copy_x.grad.data return lgv_data - + def drift(self, t, x): state = x[:, :-1] state = torch.nan_to_num(state) grad = torch.clip(self.score(state), -1e4, 1e4) f = torch.clip(self.net(t, state), -1e4, 1e4) dx = torch.nan_to_num(f + self.tcond(t) * grad) - norm = 0.5 * (dx ** 2).sum(dim=1, keepdim=True) + norm = 0.5 * (dx**2).sum(dim=1, keepdim=True) return torch.cat([dx * self.pis_scale, norm], dim=-1) def diffusion(self, t, x): @@ -284,7 +282,7 @@ def diffusion(self, t, x): return coeff def fwd_traj(self): - start_time = 0. + start_time = 0.0 end_time = self.time_range dt = self.time_range / self.num_integration_steps @@ -292,9 +290,7 @@ def fwd_traj(self): start_time, end_time, self.num_integration_steps + 1, device=self.device )[:-1] - state = torch.zeros( - self.eval_batch_size, self.dim + 1, device=self.device - ) + state = torch.zeros(self.eval_batch_size, self.dim + 1, device=self.device) logpf = torch.zeros(self.eval_batch_size, device=self.device) logpb = torch.zeros(self.eval_batch_size, device=self.device) @@ -305,20 +301,22 @@ def fwd_traj(self): std = self.pis_sde.g(t, state) state_ = state + dt * dx + std * noise * np.sqrt(dt) - logpf += -0.5 * (noise[..., :-1] ** 2 + logtwopi + np.log(dt) + torch.log(std[..., :-1] ** 2)).sum(1) + logpf += -0.5 * ( + noise[..., :-1] ** 2 + logtwopi + np.log(dt) + torch.log(std[..., :-1] ** 2) + ).sum(1) if t > 0: back_mean = state_[..., :-1] - dt * state_[..., :-1] / (t + dt) - back_var = (self.pis_scale ** 2) * dt * t / (t + dt) + back_var = (self.pis_scale**2) * dt * t / (t + dt) noise_backward = (state[..., :-1] - back_mean) / torch.sqrt(back_var) - logpb += -0.5 * (noise_backward ** 2 + logtwopi + torch.log(back_var)).sum(1) + logpb += -0.5 * (noise_backward**2 + logtwopi + torch.log(back_var)).sum(1) state = state_ - + return state[..., :-1], logpf, logpb - + def bwd_traj(self, data): - start_time = 0. 
+ start_time = 0.0 end_time = self.time_range dt = self.time_range / self.num_integration_steps @@ -334,23 +332,26 @@ def bwd_traj(self, data): for t in times: if t > dt: back_mean = state - dt * state / t - back_var = ((self.pis_scale ** 2) * dt * (t - dt)) / t + back_var = ((self.pis_scale**2) * dt * (t - dt)) / t noise = torch.randn_like(state, device=self.device) state_ = back_mean + torch.sqrt(back_var) * noise - log_pb += -0.5 * (noise ** 2 + logtwopi + torch.log(back_var)).sum(1) + log_pb += -0.5 * (noise**2 + logtwopi + torch.log(back_var)).sum(1) else: state_ = torch.zeros_like(state, device=self.device) aug_state = torch.cat([state_, torch.zeros_like(state_[..., :1])], dim=-1) forward_mean = self.pis_sde.f(t - dt, aug_state)[..., :-1] forward_var = self.pis_sde.g(t - dt, aug_state)[..., :-1] ** 2 - - noise = ((state - state_) - dt * forward_mean) / (np.sqrt(dt) * torch.sqrt(forward_var)) - log_pf += -0.5 * (noise ** 2 + logtwopi + np.log(dt) + torch.log(forward_var)).sum( - 1) - + + noise = ((state - state_) - dt * forward_mean) / ( + np.sqrt(dt) * torch.sqrt(forward_var) + ) + log_pf += -0.5 * (noise**2 + logtwopi + np.log(dt) + torch.log(forward_var)).sum( + 1 + ) + state = state_ - + return log_pf, log_pb def step_with_uw(self, t, state, dt): @@ -373,9 +374,7 @@ def get_cfm_loss(self, samples: torch.Tensor) -> torch.Tensor: # ) x1 = samples - t, xt, ut = self.conditional_flow_matcher.sample_location_and_conditional_flow( - x0, x1 - ) + t, xt, ut = self.conditional_flow_matcher.sample_location_and_conditional_flow(x0, x1) vt = self.cfm_net(t, xt) return (vt - ut).pow(2).mean(dim=-1) @@ -394,7 +393,7 @@ def get_loss(self): return_full_trajectory=True, no_grad=False, reverse_time=False, - time_range=self.time_range + time_range=self.time_range, )[-1] x_1, quad_reg = aug_output[..., :-1], aug_output[..., -1] prior_ll = self.prior.log_prob(x_1).mean() / (self.dim + 1) @@ -405,7 +404,6 @@ def get_loss(self): return pis_loss, prior_ll, sample_ll, quad_reg, term_loss - def generate_samples( self, sde, @@ -414,9 +412,7 @@ def generate_samples( diffusion_scale=1.0, ) -> torch.Tensor: num_samples = num_samples or self.num_samples_to_generate_per_epoch - samples = torch.zeros( - num_samples, self.dim + 1, device=self.device - ) + samples = torch.zeros(num_samples, self.dim + 1, device=self.device) return self.integrate( sde=sde, @@ -424,18 +420,18 @@ def generate_samples( reverse_time=False, return_full_trajectory=return_full_trajectory, diffusion_scale=diffusion_scale, - time_range=self.time_range + time_range=self.time_range, )[..., :-1] def integrate( self, - sde = None, + sde=None, samples: torch.Tensor = None, reverse_time=True, return_full_trajectory=False, diffusion_scale=1.0, no_grad=True, - time_range=1. 
+ time_range=1.0, ) -> torch.Tensor: trajectory = integrate_sde( sde or self.pis_sde, @@ -444,7 +440,7 @@ def integrate( diffusion_scale=diffusion_scale, reverse_time=reverse_time, no_grad=no_grad, - time_range=time_range + time_range=time_range, ) if return_full_trajectory: return trajectory @@ -463,13 +459,13 @@ def compute_nll( ) num_integration_steps = self.num_integration_steps - if self.nll_integration_method == 'dopri5': + if self.nll_integration_method == "dopri5": num_integration_steps = 1 aug_output = cnf.integrate( aug_samples, num_integration_steps=num_integration_steps, - method=self.nll_integration_method + method=self.nll_integration_method, )[-1] x_1, logdetjac = aug_output[..., :-1], aug_output[..., -1] log_p_1 = prior.log_prob(x_1) @@ -504,9 +500,7 @@ def training_step(self, batch, batch_idx): cfm_loss = self.get_cfm_loss(cfm_samples) self.log_dict( - t_stratified_loss( - times, cfm_loss, loss_name="train/stratified/cfm_loss" - ) + t_stratified_loss(times, cfm_loss, loss_name="train/stratified/cfm_loss") ) cfm_loss = cfm_loss.mean() self.cfm_train_loss(cfm_loss) @@ -518,7 +512,7 @@ def training_step(self, batch, batch_idx): on_epoch=False, prog_bar=True, ) - + self.log_dict( { "train/pis_loss": self.pis_train_loss, @@ -527,15 +521,15 @@ def training_step(self, batch, batch_idx): "train/pis_reg_loss": self.pis_reg_loss, "train/pis_term_loss": self.pis_term_loss, }, - on_step=True, on_epoch=False, prog_bar=True + on_step=True, + on_epoch=False, + prog_bar=True, ) return loss def compute_log_z(self, cnf, prior, samples, prefix, name): - nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll( - cnf, prior, samples - ) + nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll(cnf, prior, samples) logz = self.energy_function(samples) + nll logz_metric = getattr(self, f"{prefix}_{name}logz") logz_metric.update(logz) @@ -549,9 +543,7 @@ def compute_log_z(self, cnf, prior, samples, prefix, name): def compute_and_log_nll(self, cnf, prior, samples, prefix, name): cnf.nfe = 0.0 - nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll( - cnf, prior, samples - ) + nll, forwards_samples, logdetjac, log_p_1 = self.compute_nll(cnf, prior, samples) nfe_metric = getattr(self, f"{prefix}_{name}nfe") nll_metric = getattr(self, f"{prefix}_{name}nll") logdetjac_metric = getattr(self, f"{prefix}_{name}nll_logdetjac") @@ -579,15 +571,22 @@ def compute_and_log_nll(self, cnf, prior, samples, prefix, name): return forwards_samples def pis_logZ(self): - times = torch.linspace(0., self.time_range, self.num_integration_steps + 1).to(self.device)[:-1] + times = torch.linspace(0.0, self.time_range, self.num_integration_steps + 1).to( + self.device + )[:-1] uw_term = 0 - dt = 1. 
/ self.num_integration_steps + dt = 1.0 / self.num_integration_steps state = torch.zeros(self.eval_batch_size, self.dim + 1).to(self.device) for t in times: state, cur_uw_term = self.step_with_uw(t, state, dt) uw_term += cur_uw_term - - loss = state[:, -1] + uw_term + self.prior.log_prob(state[:, :-1]) - self.energy_function(state[:, :-1]) + + loss = ( + state[:, -1] + + uw_term + + self.prior.log_prob(state[:, :-1]) + - self.energy_function(state[:, :-1]) + ) log_weight = -loss + loss.mean() unnormal_weight = torch.exp(log_weight) @@ -601,7 +600,7 @@ def pis_logZ(self): logZ = torch.log(torch.mean(torch.exp(log_weight))) - loss.mean() return state[:, :-1], logZ_lb, logZ_ub, logZ_hb, logZ - + def gfn_log_Z(self): state, log_pf, log_pb = self.fwd_traj() log_r = self.energy_function(state) @@ -614,7 +613,7 @@ def gfn_log_Z(self): def get_elbo(self, data, prefix, name, num_evals=10): bsz = data.shape[0] data = data.view(bsz, 1, self.dim).repeat(1, num_evals, 1).view(bsz * num_evals, self.dim) - log_pf, log_pb = self.bwd_traj(data) + log_pf, log_pb = self.bwd_traj(data) log_weight = (log_pf - log_pb).view(bsz, num_evals) log_weight = logmeanexp(log_weight, dim=1) elbo_metric = getattr(self, f"{prefix}_{name}gfn_elbo") @@ -651,9 +650,7 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None: labels. :param batch_idx: The index of the current batch. """ - batch = self.energy_function.sample_test_set( - self.num_samples_to_generate_per_epoch - ) + batch = self.energy_function.sample_test_set(self.num_samples_to_generate_per_epoch) loss = self.get_loss()[0] # update and log metrics @@ -661,18 +658,16 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None: loss_metric(loss) if self.last_samples is None: - self.last_samples = self.generate_samples(self.pis_sde, diffusion_scale=self.diffusion_scale) + self.last_samples = self.generate_samples( + self.pis_sde, diffusion_scale=self.diffusion_scale + ) self.outputs[f"{prefix}/data"] = batch self.outputs[f"{prefix}/gen"] = self.last_samples - self.log( - f"{prefix}/loss", loss_metric, on_step=False, on_epoch=True, prog_bar=True - ) + self.log(f"{prefix}/loss", loss_metric, on_step=False, on_epoch=True, prog_bar=True) - batch = self.energy_function.sample_test_set( - self.eval_batch_size - ) + batch = self.energy_function.sample_test_set(self.eval_batch_size) self.get_elbo(batch, prefix, "") @@ -683,9 +678,7 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None: self.outputs[f"{prefix}/cfm_prior"] = prior_samples if self.compute_nll_on_train_data: - train_samples = self.energy_function.sample_train_set( - self.eval_batch_size - ) + train_samples = self.energy_function.sample_train_set(self.eval_batch_size) _ = self.compute_and_log_nll( self.cfm_cnf, self.cfm_prior, train_samples, prefix, "train_" ) @@ -713,7 +706,7 @@ def fxn(t, x, args=None): ) num_integration_steps = self.num_integration_steps - if self.nll_integration_method == 'dopri5': + if self.nll_integration_method == "dopri5": num_integration_steps = 2 # noise = torch.randn(shape, device=self.device) * self.cfm_prior_std @@ -731,10 +724,8 @@ def fxn(t, x, args=None): print(e) print("Falling back on fixed-step integration") self.nfe = 0.0 - time = torch.linspace(0, 1, 1000+1, device=self.device) - return odeint( - reverse_wrapper(self.cfm_net), noise, t=time, method="euler" - )[-1] + time = torch.linspace(0, 1, 1000 + 1, device=self.device) + return odeint(reverse_wrapper(self.cfm_net), noise, t=time, method="euler")[-1] 
def scatter_prior(self, prefix, outputs): wandb_logger = get_wandb_logger(self.loggers) @@ -767,25 +758,35 @@ def eval_epoch_end(self, prefix: str): f"{prefix}/pis_logZ": pis_logZ, } - self.energy_function.log_samples(samples_gfn, wandb_logger, f"{prefix}_samples/gfn_samples") - self.energy_function.log_samples(samples_pis, wandb_logger, f"{prefix}_samples/pis_samples") - self.energy_function.log_samples(buffer_samples, wandb_logger, f"{prefix}_samples/buffer_samples") + self.energy_function.log_samples( + samples_gfn, wandb_logger, f"{prefix}_samples/gfn_samples" + ) + self.energy_function.log_samples( + samples_pis, wandb_logger, f"{prefix}_samples/pis_samples" + ) + self.energy_function.log_samples( + buffer_samples, wandb_logger, f"{prefix}_samples/buffer_samples" + ) if self.nll_with_cfm: # Generate data from the CFM # Calculate logZ based on that data cfm_samples = self.generate_cfm_samples(self.eval_batch_size) - self.compute_log_z( - self.cfm_cnf, self.cfm_prior, cfm_samples, prefix, "" - ) + self.compute_log_z(self.cfm_cnf, self.cfm_prior, cfm_samples, prefix, "") unprioritized_buffer_samples, _, _ = self.buffer.sample( self.eval_batch_size, prioritize=self.prioritize_cfm_training_samples, ) if self.energy_function.dimensionality == 2: self.scatter_prior(prefix + "_samples", self.outputs[f"{prefix}/cfm_prior"]) - self.energy_function.log_samples(cfm_samples, wandb_logger, f"{prefix}_samples/cfm_samples") - self.energy_function.log_samples(unprioritized_buffer_samples, wandb_logger, f"{prefix}_samples/unprioritized_buffer_samples") + self.energy_function.log_samples( + cfm_samples, wandb_logger, f"{prefix}_samples/cfm_samples" + ) + self.energy_function.log_samples( + unprioritized_buffer_samples, + wandb_logger, + f"{prefix}_samples/unprioritized_buffer_samples", + ) # pad with time dimension 1 names, dists = compute_distribution_distances( @@ -831,9 +832,7 @@ def setup(self, stage: str) -> None: self.cfm_net = torch.compile(self.cfm_net) if self.nll_with_cfm: - self.cfm_prior = self.partial_prior( - device=self.device, scale=self.cfm_prior_std - ) + self.cfm_prior = self.partial_prior(device=self.device, scale=self.cfm_prior_std) def configure_optimizers(self) -> Dict[str, Any]: """Choose what optimizers and learning-rate schedulers to use in your optimization. 
@@ -856,4 +855,4 @@ def configure_optimizers(self) -> Dict[str, Any]: "frequency": self.hparams.lr_scheduler_update_frequency, }, } - return {"optimizer": optimizer} \ No newline at end of file + return {"optimizer": optimizer} diff --git a/dem/train.py b/dem/train.py index 4cc0181..d82b7ca 100644 --- a/dem/train.py +++ b/dem/train.py @@ -60,9 +60,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: energy_function = hydra.utils.instantiate(cfg.energy) log.info(f"Instantiating model <{cfg.model._target_}>") - model: LightningModule = hydra.utils.instantiate( - cfg.model, energy_function=energy_function - ) + model: LightningModule = hydra.utils.instantiate(cfg.model, energy_function=energy_function) log.info("Instantiating callbacks...") callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) @@ -72,10 +70,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: log.info(f"Instantiating trainer <{cfg.trainer._target_}>") trainer: Trainer = hydra.utils.instantiate( - cfg.trainer, - callbacks=callbacks, - logger=logger, - num_sanity_val_steps=0 + cfg.trainer, callbacks=callbacks, logger=logger, num_sanity_val_steps=0 ) object_dict = { diff --git a/dem/utils/data_utils.py b/dem/utils/data_utils.py index 9949cc6..37b4f68 100644 --- a/dem/utils/data_utils.py +++ b/dem/utils/data_utils.py @@ -2,8 +2,7 @@ def remove_mean(samples, n_particles, n_dimensions): - """ - Makes a configuration of many particle system mean-free. + """Makes a configuration of many particle system mean-free. Parameters ---------- @@ -31,9 +30,7 @@ def interatomic_dist(samples): n_particles = samples.shape[-2] # Compute the pairwise differences and distances distances = samples[:, None, :, :] - samples[:, :, None, :] - distances = distances[ - :, torch.triu(torch.ones((n_particles, n_particles)), diagonal=1) == 1 - ] + distances = distances[:, torch.triu(torch.ones((n_particles, n_particles)), diagonal=1) == 1] dist = torch.linalg.norm(distances, dim=-1) diff --git a/dem/utils/logging_utils.py b/dem/utils/logging_utils.py index 5f41bb2..e0f088e 100644 --- a/dem/utils/logging_utils.py +++ b/dem/utils/logging_utils.py @@ -1,7 +1,6 @@ -import PIL - from typing import Any, Dict +import PIL from lightning_utilities.core.rank_zero import rank_zero_only from omegaconf import OmegaConf @@ -63,6 +62,4 @@ def log_hyperparameters(object_dict: Dict[str, Any]) -> None: def fig_to_image(fig): fig.canvas.draw() - return PIL.Image.frombytes( - "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() - ) + return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) diff --git a/dem/utils/pylogger.py b/dem/utils/pylogger.py index 31a76c3..c4ee867 100644 --- a/dem/utils/pylogger.py +++ b/dem/utils/pylogger.py @@ -24,9 +24,7 @@ def __init__( super().__init__(logger=logger, extra=extra) self.rank_zero_only = rank_zero_only - def log( - self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs - ) -> None: + def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: """Delegate a log call to the underlying logger, after prefixing its message with the rank of the process it's being logged from. If `'rank'` is provided, then the log will only occur on that rank/process. 
@@ -41,9 +39,7 @@ def log( msg, kwargs = self.process(msg, kwargs) current_rank = getattr(rank_zero_only, "rank", None) if current_rank is None: - raise RuntimeError( - "The `rank_zero_only.rank` needs to be set before use" - ) + raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") msg = rank_prefixed_message(msg, current_rank) if self.rank_zero_only: if current_rank == 0: diff --git a/dem/utils/utils.py b/dem/utils/utils.py index c4d02e7..101ffcd 100644 --- a/dem/utils/utils.py +++ b/dem/utils/utils.py @@ -95,9 +95,7 @@ def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: return wrap -def get_metric_value( - metric_dict: Dict[str, Any], metric_name: Optional[str] -) -> Optional[float]: +def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]: """Safely retrieves value of the metric logged in LightningModule. :param metric_dict: A dict containing metric values. diff --git a/environment.yaml b/environment.yaml index 1ada529..d9829db 100644 --- a/environment.yaml +++ b/environment.yaml @@ -62,4 +62,3 @@ dependencies: - scikit-learn - scipy - -e . -
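A closing note on the replay-buffer hunks earlier in this patch: the reformatted `ReplayBuffer.sample` keeps its behaviour, drawing indices with probability proportional to `(1 / rank) ** temperature`, where `rank` counts how many `add` rounds ago an entry was written. `temperature = 0.0` therefore recovers uniform sampling (as the file's own example at the bottom uses, and the only case in which the inline `# sample uniformly` comment holds), while larger temperatures favour recently inserted samples. A self-contained sketch with illustrative values, not the repository's exact API:

```python
import torch

# Toy setup with made-up sizes: 128 buffer slots written over 10 rounds of
# add(); add_count records the round in which each slot was last filled.
current_add_count = 10
add_count = torch.randint(0, current_add_count, (128,))
temperature = 1.0  # 0.0 would make every slot equally likely

rank = (current_add_count - add_count).float()  # 1 = filled most recently
probs = torch.pow(1.0 / rank, temperature)      # recent slots get more weight
indices = torch.multinomial(probs, num_samples=16, replacement=False)
print(indices)
```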