Learning to Learn with Generative Models of Neural Network Checkpoints
Official PyTorch Implementation
This repo contains training, evaluation, and visualization code for our recent paper exploring loss-conditional diffusion models of neural network parameters.
William Peebles*, Ilija Radosavovic*, Tim Brooks, Alexei A. Efros, Jitendra Malik
University of California, Berkeley
Our generative models are conditioned on a starting parameter vector, a starting loss/error/return, and a prompted loss/error/return. With these inputs, we can sample an updated parameter vector that ideally achieves the prompt. We call our model G.pt (G and .pt refer to generative models and checkpoint extensions, respectively). The core of G.pt is a transformer that operates over sequences of parameters from the input neural network. Similar to ViTs, G.pt leverages very few domain-specific inductive biases (only in tokenization and data augmentation). The transformer is trained as a diffusion model directly in parameter space. After training, G.pt can optimize neural networks from random initialization in one step by prompting for a small loss/error or high return. In this paper, we introduce G.pt models for optimizing MNIST MLPs, CIFAR-10 CNNs and Cartpole Gaussian MLPs.
This repository contains:
- ⚡️ Five pre-trained G.pt DDPM Transformers for vision and RL tasks
- 🪐 A dataset containing over 23M neural net checkpoints across 100K+ training runs
- 💥 Training and testing scripts for G.pt models
First, download and set up the repo:
git clone https://github.com/wpeebles/G.pt.git
cd G.pt
pip install -e .
We provide an environment.yml file that can be used to create a Conda environment:
conda env create -f environment.yml
conda activate G.pt
If you opt to use your own environment, you'll need Python 3.8 in order to run the IsaacGym RL simulator (newer versions of Python may have compatibility issues). You'll also want to set up a Weights & Biases account since some of our visualization code uses it.
Finally, in order to train or evaluate RL G.pt models, you'll need to install IsaacGym. Download it from NVIDIA, then install it:
cd /path/to/isaac-gym/python
pip install -e .
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/anaconda3/envs/G.pt/lib
We provide three checkpoint datasets, in aggregate containing over 23M checkpoints from 100K+ training runs. Each individual checkpoint contains neural network parameters and any useful task-specific metadata (e.g., test losses and errors for classification, episode returns for RL). If you run our G.pt testing scripts (explained below), the relevant checkpoint data will be auto-downloaded. Or, you can download all three checkpoint datasets (and the five pre-trained G.pt models) by running:
python Gpt/download.py
This will store all three datasets in a folder named checkpoint_datasets. The breakdown for each dataset is as follows:
Checkpoint Dataset | # Checkpoints | # Runs | # Checkpoints/Run | Storage (GB) |
---|---|---|---|---|
MNIST MLPs | 2.1M | 10728 | 200 | 68 |
CIFAR-10 CNNs | 11.3M | 56840 | 200 | 275 |
Cartpole Gaussian MLPs | 10M | 50026 | 200 | 82 |
Additional information about the datasets can be found in our paper.
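As a quick sanity check on the table, the sketch below recomputes the aggregate figures from the per-dataset rows. It assumes every run contributes exactly 200 checkpoints, as the "# Checkpoints/Run" column states, so the totals can differ slightly from the rounded "# Checkpoints" column. The dictionary layout is ours and is not part of the repo.

```python
# Rows copied from the table above: (number of runs, checkpoints per run, storage in GB).
datasets = {
    "MNIST MLPs": (10728, 200, 68),
    "CIFAR-10 CNNs": (56840, 200, 275),
    "Cartpole Gaussian MLPs": (50026, 200, 82),
}

# Assumption: each run holds exactly the listed number of checkpoints.
total_checkpoints = sum(runs * per_run for runs, per_run, _ in datasets.values())
total_runs = sum(runs for runs, _, _ in datasets.values())
total_storage_gb = sum(gb for _, _, gb in datasets.values())

print(f"checkpoints: {total_checkpoints:,}")
print(f"runs:        {total_runs:,}")
print(f"storage:     {total_storage_gb} GB")
```

The totals are consistent with the "over 23M checkpoints from 100K+ training runs" figure. The storage sum covers the three checkpoint datasets only, so plan for at least that much free disk space before downloading all of them.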
Using our pre-trained G.pt models. We provide five pre-trained G.pt models. They can be used by setting the config's resume_path to one of cartpole.pt, mnist_loss.pt, mnist_error.pt, cifar_loss.pt or cifar_error.pt. If you use one of these values, the relevant model will be automatically downloaded and cached in the pretrained_models folder.
You can evaluate G.pt models by running main.py with the test_only=True flag. We provide several testing configs in configs/test. For example, to create one-step/recursive optimization curves and compute prompt alignment scores for our loss-conditional MNIST model, you can run:
python main.py --config-path configs/test --config-name mnist_loss.yaml num_gpus=N
where N is the number of GPUs to distribute the evaluation across. Running this command will generate plots in Weights & Biases. There are several options in the configs you can change.
We also include a playground.py script, which gives a more minimal example of loading and running G.pt models on one GPU. It can be used to generate latent walks through parameter space for our CIFAR-10 G.pt models:
python playground.py --config-name cifar_error.yaml
The main entrypoint for training new G.pt models is main.py. Our training configs are included in the repo. Example usage to train a G.pt model using N GPUs:
python main.py --config-name mnist_loss.yaml num_gpus=N
To add a new task, you need to update the TASK_METADATA dictionary in tasks.py with a new entry. You'll need to come up with a name for the task which will be the new key. The corresponding value should be a dictionary with the following items:
(1) task_test_fn, a function mapping any needed inputs and an nn.Module instance to a loss/error/return/etc. Your function should explicitly take as input anything that is expensive to construct multiple times (e.g., datasets, dataloaders, simulators). You will specify any of these expensive inputs via data_fn below. Make sure the nn.Module is the last input to this function.
(2) constructor, a function that constructs a randomly-initialized nn.Module with the correct architecture. This constructor needs to produce the correct architecture when called without any arguments.
(3) data_fn, a function that outputs any inputs needed to call task_test_fn (besides the input nn.Module) which should be cached. For example, data and environment instances should be returned by data_fn, and they will then be automatically passed to task_test_fn whenever it is invoked. This avoids expensive re-instantiation of these objects every time we want to call task_test_fn (which is usually a lot of times). data_fn must always return a tuple or list, even if it returns only one thing or nothing.
(4) minimize, a boolean that indicates if the goal is to minimize or maximize the output of task_test_fn.
(5) best_prompt, a float representing the "best" loss/error/return/etc. you want to prompt G.pt with for one-step optimization.
(6) recursive_prompt, a float representing the loss/error/return/etc. you want to repeatedly prompt G.pt with when performing recursive optimization.
(Optional, but recommended) You can also include an aug_fn key that maps to a function that performs a loss-preserving augmentation on the neural network parameters directly.
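To make "loss-preserving augmentation" concrete, here is a minimal, dependency-free sketch of one well-known example: permuting the hidden units of a two-layer MLP. It illustrates the concept only. The signature the repo expects for aug_fn is not described here, so permute_hidden_units and mlp_forward are hypothetical helpers operating on plain Python lists rather than on nn.Module parameters.

```python
import math
import random


def mlp_forward(x, w1, b1, w2, b2):
    """Two-layer MLP with tanh hidden units; weights are nested lists."""
    hidden = [
        math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
        for row, b in zip(w1, b1)
    ]
    return [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(w2, b2)]


def permute_hidden_units(w1, b1, w2, perm):
    """Reorder hidden units: rows of w1, entries of b1 and columns of w2."""
    w1_p = [w1[j] for j in perm]
    b1_p = [b1[j] for j in perm]
    w2_p = [[row[j] for j in perm] for row in w2]
    return w1_p, b1_p, w2_p


rng = random.Random(0)
n_in, n_hidden, n_out = 3, 4, 2
w1 = [[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(n_hidden)]
b1 = [rng.gauss(0, 1) for _ in range(n_hidden)]
w2 = [[rng.gauss(0, 1) for _ in range(n_hidden)] for _ in range(n_out)]
b2 = [rng.gauss(0, 1) for _ in range(n_out)]
x = [0.5, -1.0, 2.0]

perm = [2, 0, 3, 1]  # any permutation of the hidden-unit indices works
w1_p, b1_p, w2_p = permute_hidden_units(w1, b1, w2, perm)

before = mlp_forward(x, w1, b1, w2, b2)
after = mlp_forward(x, w1_p, b1_p, w2_p, b2)
max_diff = max(abs(a - b) for a, b in zip(before, after))
print(f"max output difference after permutation: {max_diff:.2e}")
```

The permuted parameter vector differs from the original, yet the network computes the same function up to floating-point rounding, so any loss, error or return measured on it is unchanged.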
Finally, make sure you pass the name of your new task via dataset.name.
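The contract described above can be sketched as follows. This is an illustration under stated assumptions, not code from the repo: to stay dependency-free it uses a tiny stand-in class (TinyRegressor) where the repo expects an nn.Module, and the task name, function names and prompt values are hypothetical placeholders. Only the dictionary keys come from the list above.

```python
import random
from typing import List, Tuple

Dataset = List[Tuple[float, float]]  # (input, target) pairs


class TinyRegressor:
    """Stand-in for an nn.Module: a 1-D linear model y = w * x + b."""

    def __init__(self) -> None:
        # Random initialization; no constructor arguments are required.
        self.w = random.gauss(0.0, 1.0)
        self.b = random.gauss(0.0, 1.0)

    def __call__(self, x: float) -> float:
        return self.w * x + self.b


def toy_data_fn() -> Tuple[Dataset]:
    """Build the expensive inputs once; always return a tuple or list."""
    dataset = [(float(x), 2.0 * x + 1.0) for x in range(-5, 6)]
    return (dataset,)


def toy_test_fn(dataset: Dataset, model: TinyRegressor) -> float:
    """Mean squared error of the model; the model is the last argument."""
    return sum((model(x) - y) ** 2 for x, y in dataset) / len(dataset)


# Hypothetical entry. In the repo it would be stored under a new key, e.g.
# TASK_METADATA["toy_regression"] in tasks.py.
TOY_TASK_ENTRY = {
    "task_test_fn": toy_test_fn,
    "constructor": TinyRegressor,  # callable with no arguments
    "data_fn": toy_data_fn,
    "minimize": True,              # lower MSE is better
    "best_prompt": 0.0,            # placeholder value for one-step optimization
    "recursive_prompt": 0.0,       # placeholder value for recursive optimization
}

# How the pieces fit together: build the cached inputs once, then evaluate models.
cached_inputs = TOY_TASK_ENTRY["data_fn"]()
model = TOY_TASK_ENTRY["constructor"]()
loss = TOY_TASK_ENTRY["task_test_fn"](*cached_inputs, model)
print(f"loss at random initialization: {loss:.3f}")
```

With an entry like this registered under the hypothetical key toy_regression, the task would be selected by setting dataset.name to that key. Returning a one-element tuple from toy_data_fn keeps the unpacking in the last lines uniform, which is why data_fn must return a tuple or list even for a single input.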
We provide several scripts to facilitate checkpoint generation in the data_gen folder. We include example files that generate supervised learning and reinforcement learning checkpoints. You can use the train_batch.py script to indefinitely launch single-GPU task-level training jobs on a node to collect a large number of checkpoints. After saving enough checkpoints, you'll want to filter out any with bad parameter values (e.g., NaNs). You can do this with prepare_checkpoints.py (be sure to update it with a path to your new directory of checkpoints):
python Gpt/data/prepare_checkpoints.py
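For intuition, the kind of filtering involved can be sketched as below. This is not the implementation in prepare_checkpoints.py. It assumes each checkpoint has already been flattened to a list of floats, and the checkpoint names and the helper is_finite_checkpoint are hypothetical.

```python
import math
from typing import Dict, List


def is_finite_checkpoint(params: List[float]) -> bool:
    """True if no parameter value is NaN or infinite."""
    return all(math.isfinite(p) for p in params)


# Toy stand-ins for flattened parameter vectors loaded from disk.
checkpoints: Dict[str, List[float]] = {
    "run0/step_100": [0.1, -0.3, 0.7],
    "run0/step_200": [0.2, float("nan"), 0.5],
    "run1/step_100": [float("inf"), 0.0, 0.1],
}

# Keep only checkpoints whose parameters are all finite.
kept = {name: p for name, p in checkpoints.items() if is_finite_checkpoint(p)}
print(sorted(kept))
```

Checking for infinities as well as NaNs is a design choice of this sketch. Both typically indicate a diverged training run.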
prepare_checkpoints.py will also compute the variance over parameter values in your dataset, which you can pass via dataset.openai_coeff in your training/testing config file in order to normalize the parameters before they are processed by G.pt.
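The statistic in question can be sketched as a single pass over all parameter values. The helper pooled_variance is hypothetical, and checkpoints are again assumed to be flat lists of floats. How G.pt applies dataset.openai_coeff internally is not described here, so the division by the standard deviation at the end is an assumption made purely for illustration.

```python
import math
from typing import Iterable, List


def pooled_variance(checkpoints: Iterable[List[float]]) -> float:
    """Population variance over every parameter value in every checkpoint."""
    count = 0
    total = 0.0
    total_sq = 0.0
    for params in checkpoints:
        for p in params:
            count += 1
            total += p
            total_sq += p * p
    mean = total / count
    return total_sq / count - mean * mean


checkpoints = [[1.0, 2.0], [3.0, 4.0]]
variance = pooled_variance(checkpoints)

# Assumed normalization for illustration: rescale to unit variance.
std = math.sqrt(variance)
normalized = [[p / std for p in params] for params in checkpoints]
print(f"variance before: {variance:.4f}, after: {pooled_variance(normalized):.4f}")
```

Accumulating the count, sum and sum of squares means the dataset never has to fit in memory at once, which matters at the scale of millions of checkpoints. For very large or badly scaled values, a numerically stabler method such as Welford's algorithm may be preferable.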
@article{Peebles2022,
title={Learning to Learn with Generative Models of Neural Network Checkpoints},
author={William Peebles and Ilija Radosavovic and Tim Brooks and Alexei Efros and Jitendra Malik},
year={2022},
journal={arXiv preprint arXiv:2209.12892},
}
We thank Anastasios Angelopoulos, Shubham Goel, Allan Jabri, Michael Janner, Assaf Shocher, Aravind Srinivas, Matthew Tancik, Tete Xiao, Saining Xie and Jun-Yan Zhu for helpful discussions. William Peebles and Tim Brooks are supported by the NSF Graduate Research Fellowship. Additional support provided by the DARPA program on Machine Common Sense.
This codebase borrows from OpenAI's diffusion repos, most notably ADM, and Andrej Karpathy's minGPT. We thank the authors for open-sourcing their work.