Skip to content

Commit

Permalink
first format pass
Browse files Browse the repository at this point in the history
  • Loading branch information
atong01 committed Feb 8, 2024
1 parent 5f0c793 commit 9f1dcf2
Show file tree
Hide file tree
Showing 52 changed files with 411 additions and 631 deletions.
19 changes: 12 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
TODO: FIX PREPRINT BUTTON AFTER WE'RE ON ARXIV!!!

## Description
This is the official repository for the paper [Iterated Denoising Energy Matching for Sampling from Boltzmann Densities](https://arxiv.org/abs/2310.02391) (TODO: FIX THIS LINK AFTER WE'RE ON ARXIV).

This is the official repository for the paper [Iterated Denoising Energy Matching for Sampling from Boltzmann Densities](https://arxiv.org/abs/2310.02391) (TODO: FIX THIS LINK AFTER WE'RE ON ARXIV).

We propose iDEM, a scalable and efficient method to sample from unnormalized probability distributions. iDEM makes use of the DEM objective, inspired by the stochastic regression and simulation
free principles of score and flow matching objectives while allowing one to learn off-policy, in a loop while itself generating (optionally exploratory) new states which are subsequently
learned on. iDEM is also capable of incorporating symmetries, namely those represented by the product group of $SE(3) \times \mathbb{S}_n$. We experiment on a 2D GMM task as well as a number of physics
learned on. iDEM is also capable of incorporating symmetries, namely those represented by the product group of $SE(3) \\times \\mathbb{S}\_n$. We experiment on a 2D GMM task as well as a number of physics
inspired problems. These include:

- DW4 -- the 4 particle double well potential (8 dimensions total)
Expand All @@ -29,13 +30,14 @@ out most of the code and experiments with help from [@sarthmit](https://github.c

For installation we recommend the use of Micromamba. Please refer [here](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html) for an installation guide for Micromamba.
First, we install dependencies

```bash
# clone project
git clone [email protected]:jarridrb/DEM.git
cd DEM

# create micromamba environment
micromamba create -f environment.yaml
micromamba create -f environment.yaml
micromamba activate dem

# install requirements
Expand All @@ -48,25 +50,29 @@ an example `.env.example` file for convenience. Note that to use wandb we requir
`.env` file.

To run an experiment, e.g., GMM with iDEM, you can run on the command line

```bash
python dem/train.py experiment=gmm_idem
```

We include configs for all experiments matching the settings we used in our paper for both iDEM and pDEM with the exception of LJ55 for
which we only include a config for iDEM and pDEM had convergence issues on LJ55.

## Current Code
## Current Code

The current repository contains code for experiments for iDEM and pDEM as specified in our paper.

## Planned Updates

- [ ] Code to do Langevin on top of generated samples

## Citations

If this codebase is useful towards other research efforts please consider citing us. TODO: FIX THIS CITATION ONCE WE'RE ON ARXIV!!!

```
@misc{bose2023se3stochastic,
title={SE(3)-Stochastic Flow Matching for Protein Backbone Generation},
title={SE(3)-Stochastic Flow Matching for Protein Backbone Generation},
author={Avishek Joey Bose and Tara Akhound-Sadegh and Kilian Fatras and Guillaume Huguet and Jarrid Rector-Brooks and Cheng-Hao Liu and Andrei Cristian Nica and Maksym Korablyov and Michael Bronstein and Alexander Tong},
year={2023},
eprint={2310.02391},
Expand All @@ -75,14 +81,13 @@ If this codebase is useful towards other research efforts please consider citing
}
```


## Contribute

We welcome issues and pull requests (especially bug fixes) and contributions.
We will try our best to improve readability and answer questions!


## Licences

This repo is licensed under the [MIT License](https://opensource.org/license/mit/).

### Warning: the current code uses PyTorch 2.0.0+
Expand Down
24 changes: 11 additions & 13 deletions configs/energy/dw4.yaml
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
_target_: dem.energies.multi_double_well_energy.MultiDoubleWellEnergy
_target_: dem.energies.multi_double_well_energy.MultiDoubleWellEnergy

dimensionality: 8
n_particles: 4
dimensionality: 8
n_particles: 4

data_from_efm: true
data_path: "data/test_split_DW4.npy"
data_path_train: "data/train_split_DW4.npy"
data_path_val: "data/val_split_DW4.npy"
data_from_efm: true
data_path: "data/test_split_DW4.npy"
data_path_train: "data/train_split_DW4.npy"
data_path_val: "data/val_split_DW4.npy"

device: ${trainer.accelerator}
device: ${trainer.accelerator}

plot_samples_epoch_period: 1
plot_samples_epoch_period: 1

data_normalization_factor: 1.0
data_normalization_factor: 1.0

is_molecule: true


is_molecule: true
18 changes: 9 additions & 9 deletions configs/energy/gmm.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
_target_: dem.energies.gmm_energy.GMM
_target_: dem.energies.gmm_energy.GMM

dimensionality: 2
n_mixes: 40
loc_scaling: 40
log_var_scaling: 1.0
dimensionality: 2
n_mixes: 40
loc_scaling: 40
log_var_scaling: 1.0

device: ${trainer.accelerator}
device: ${trainer.accelerator}

plot_samples_epoch_period: 1
plot_samples_epoch_period: 1

should_unnormalize: true
data_normalization_factor: 50
should_unnormalize: true
data_normalization_factor: 50
20 changes: 10 additions & 10 deletions configs/energy/lj13.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
_target_: dem.energies.lennardjones_energy.LennardJonesEnergy
_target_: dem.energies.lennardjones_energy.LennardJonesEnergy

dimensionality: 39
n_particles: 13
data_path: "data/test_split_LJ13-1000.npy"
data_path_train: "data/train_split_LJ13-1000.npy"
data_path_val: "data/test_split_LJ13-1000.npy"
dimensionality: 39
n_particles: 13
data_path: "data/test_split_LJ13-1000.npy"
data_path_train: "data/train_split_LJ13-1000.npy"
data_path_val: "data/test_split_LJ13-1000.npy"

device: ${trainer.accelerator}
device: ${trainer.accelerator}

plot_samples_epoch_period: 1
plot_samples_epoch_period: 1

data_normalization_factor: 1.0
data_normalization_factor: 1.0

is_molecule: True
is_molecule: True
23 changes: 10 additions & 13 deletions configs/energy/lj55.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
_target_: dem.energies.lennardjones_energy.LennardJonesEnergy
_target_: dem.energies.lennardjones_energy.LennardJonesEnergy

dimensionality: 165
n_particles: 55
data_path: "data/test_split_LJ55-1000-part1.npy"
data_path_train: "data/train_split_LJ55-1000-part1.npy"
data_path_val: "data/val_split_LJ55-1000-part1.npy"
dimensionality: 165
n_particles: 55
data_path: "data/test_split_LJ55-1000-part1.npy"
data_path_train: "data/train_split_LJ55-1000-part1.npy"
data_path_val: "data/val_split_LJ55-1000-part1.npy"

device: ${trainer.accelerator}
device: ${trainer.accelerator}

plot_samples_epoch_period: 1
plot_samples_epoch_period: 1

data_normalization_factor: 1.0

is_molecule: True


data_normalization_factor: 1.0

is_molecule: True
10 changes: 5 additions & 5 deletions configs/experiment/dw4_idem.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ logger:
tags: ${tags}
group: "dw4_efm"

defaults:
defaults:
- override /energy: dw4
- override /model/net: egnn

Expand All @@ -31,8 +31,8 @@ model:

noise_schedule:
_target_: dem.models.components.noise_schedules.GeometricNoiseSchedule
sigma_min: 0.00001
sigma_max: 3
sigma_min: 0.00001
sigma_max: 3

partial_prior:
_target_: dem.energies.base_prior.MeanFreePrior
Expand All @@ -48,13 +48,13 @@ model:
_target_: dem.models.components.clipper.Clipper
should_clip_scores: True
should_clip_log_rewards: False
max_score_norm: 20
max_score_norm: 20
min_log_reward: null

# num_samples_to_sample_from_buffer: 5120
diffusion_scale: 1
num_samples_to_generate_per_epoch: 1000

init_from_prior: true

eval_batch_size: 1000
Expand Down
10 changes: 5 additions & 5 deletions configs/experiment/dw4_pdem.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ logger:
tags: ${tags}
group: "dw4_efm"

defaults:
defaults:
- override /energy: dw4
- override /model/net: egnn

Expand All @@ -31,8 +31,8 @@ model:

noise_schedule:
_target_: dem.models.components.noise_schedules.GeometricNoiseSchedule
sigma_min: 0.00001
sigma_max: 3
sigma_min: 0.00001
sigma_max: 3

partial_prior:
_target_: dem.energies.base_prior.MeanFreePrior
Expand All @@ -48,13 +48,13 @@ model:
_target_: dem.models.components.clipper.Clipper
should_clip_scores: True
should_clip_log_rewards: False
max_score_norm: 20
max_score_norm: 20
min_log_reward: null

# num_samples_to_sample_from_buffer: 5120
diffusion_scale: 1
num_samples_to_generate_per_epoch: 1000

init_from_prior: true

eval_batch_size: 1000
Expand Down
4 changes: 1 addition & 3 deletions configs/experiment/gmm_idem.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
defaults:
defaults:
- override /energy: gmm


tags: ["GMM", "iDEM"]

seed: 12345
Expand Down Expand Up @@ -60,4 +59,3 @@ model:

trainer:
check_val_every_n_epoch: 5

3 changes: 1 addition & 2 deletions configs/experiment/gmm_pdem.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
defaults:
defaults:
- override /energy: gmm


tags: ["GMM", "pDEM"]

seed: 12345
Expand Down
6 changes: 3 additions & 3 deletions configs/experiment/lj13_idem.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ logger:
tags: ${tags}
group: "lj13"

defaults:
defaults:
- override /energy: lj13
- override /model/net: egnn

Expand All @@ -27,7 +27,7 @@ model:
noise_schedule:
_target_: dem.models.components.noise_schedules.GeometricNoiseSchedule
sigma_min: 0.01
sigma_max: 2
sigma_max: 2

partial_prior:
_target_: dem.energies.base_prior.MeanFreePrior
Expand All @@ -50,7 +50,7 @@ model:
diffusion_scale: 0.9
num_samples_to_generate_per_epoch: 1000
num_samples_to_sample_from_buffer: 512

init_from_prior: true

cfm_prior_std: 2
Expand Down
6 changes: 3 additions & 3 deletions configs/experiment/lj13_pdem.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ logger:
tags: ${tags}
group: "lj13"

defaults:
defaults:
- override /energy: lj13
- override /model/net: egnn

Expand All @@ -27,7 +27,7 @@ model:
noise_schedule:
_target_: dem.models.components.noise_schedules.GeometricNoiseSchedule
sigma_min: 0.01
sigma_max: 2
sigma_max: 2

partial_prior:
_target_: dem.energies.base_prior.MeanFreePrior
Expand All @@ -50,7 +50,7 @@ model:
diffusion_scale: 0.9
num_samples_to_generate_per_epoch: 1000
num_samples_to_sample_from_buffer: 512

init_from_prior: true

cfm_prior_std: 2
Expand Down
4 changes: 1 addition & 3 deletions configs/experiment/lj55_idem.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ logger:
tags: ${tags}
group: "lj55"

defaults:
defaults:
- override /energy: lj55
- override /model/net: egnn

Expand Down Expand Up @@ -66,5 +66,3 @@ model:
logz_with_cfm: true

nll_integration_method: dopri5


5 changes: 2 additions & 3 deletions configs/model/dem.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ scheduler:

defaults:
- net:
- mlp
- mlp
- noise_schedule:
- geometric
- geometric

buffer:
_target_: dem.models.components.prioritised_replay_buffer.SimpleBuffer
Expand Down Expand Up @@ -79,4 +79,3 @@ logz_with_cfm: false

num_samples_to_save: 100000
tol: 1e-5

2 changes: 1 addition & 1 deletion configs/model/net/egnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ n_particles: 13
n_dimension: 3
hidden_nf: 32
n_layers: 3
act_fn:
act_fn:
_target_: torch.nn.SiLU
recurrent: True
tanh: True
Expand Down
2 changes: 1 addition & 1 deletion configs/model/net/pis_mlp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ _partial_: true
num_layers: 2
channels: 64
in_shape: ${energy.dimensionality}
out_shape: ${energy.dimensionality}
out_shape: ${energy.dimensionality}
Loading

0 comments on commit 9f1dcf2

Please sign in to comment.