This is the code repository for the paper Amortized Inference for Causal Structure Learning (Lorch et al., NeurIPS 2022). Amortized variational inference for causal discovery (AVICI) allows inferring causal structure from data based on a simulator of the domain of interest. By training a neural network to infer structure from simulated data, the network can acquire realistic inductive biases from prior knowledge that is hard to cast as score functions or conditional independence tests.
To install the latest stable release, run:

```bash
pip install avici
```
The package allows training new models from scratch on custom data-generating processes and performing predictions with pretrained models from our side. The codebase is written in Python and JAX.
Using the `avici` package is as easy as running the following code snippet:
```python
import avici
from avici import simulate_data

# g: [d, d] causal graph of `d` variables
# x: [n, d] data matrix containing `n` observations of the `d` variables
g, x, _ = simulate_data(d=50, n=200, domain="rff-gauss")

# load pretrained model
model = avici.load_pretrained(download="scm-v0")

# g_prob: [d, d] predicted edge probabilities of the causal graph
g_prob = model(x=x)
```
You can run a working example of this snippet directly in the following Google Colab notebook:

The above code automatically downloads a pretrained model checkpoint (~60 MB) for the specified domain, initializes it, and predicts the causal structure underlying the simulated data.
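To inspect the prediction, the edge probabilities can be thresholded and compared against the ground-truth graph. The sketch below is a minimal, self-contained illustration: it uses small hand-made numpy arrays in place of the `g` and `g_prob` returned above, and the precision/recall computation is just one possible evaluation, not part of the `avici` API.

```python
import numpy as np

# stand-ins for the arrays returned by `simulate_data` and the model call
g = np.array([[0, 1, 0],
              [0, 0, 1],
              [0, 0, 0]])             # true adjacency matrix
g_prob = np.array([[0.1, 0.9, 0.2],
                   [0.3, 0.0, 0.8],
                   [0.1, 0.2, 0.1]])  # predicted edge probabilities

# threshold the probabilities at 0.5 to obtain a hard graph prediction
g_pred = (g_prob > 0.5).astype(int)

# edge-level precision and recall against the true graph
tp = np.sum((g_pred == 1) & (g == 1))
fp = np.sum((g_pred == 1) & (g == 0))
fn = np.sum((g_pred == 0) & (g == 1))
precision = tp / (tp + fp)
recall = tp / (tp + fn)
```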
We currently provide the following model checkpoints, which can be specified via the `download` argument:

- `scm-v0`: linear and nonlinear SCM data, broad graph and noise distributions
- `neurips-linear`: SCM data with linear causal mechanisms
- `neurips-rff`: SCM data with nonlinear causal mechanisms drawn from GPs with squared-exponential kernel (defined via random Fourier features)
- `neurips-grn`: synthetic scRNA-seq gene expression data using the SERGIO simulator by Dibaeinia and Sinha (2020)
We recommend the latest `scm-v0` for working with arbitrary real-valued data. This model was trained on SCM data simulated from a large variety of graph models with up to 100 nodes, with both linear and nonlinear causal mechanisms, and with homogeneous and heterogeneous additive noise from Gaussian, Laplace, and Cauchy distributions.

The models `neurips-linear`, `neurips-rff`, and `neurips-grn` studied in our original paper were purposely trained on narrower training distributions to assess the out-of-distribution capability of AVICI. Unless your prior domain knowledge is strong, this may make the `neurips-*` models less suitable for benchmarking or as general-purpose, out-of-the-box tools in your application. The training distribution of `scm-v0` essentially combines those of `neurips-linear` and `neurips-rff` as well as their out-of-distribution settings in Lorch et al. (2022).

For details on the exact training distributions of these models, please refer to the model cards on Hugging Face. Appendix A of Lorch et al. (2022) also defines the training distributions of the `neurips-*` models. The YAML domain config file for each model is available in `avici/config/train/`.
Calling `model` as obtained from `avici.load_pretrained` predicts the `[d, d]` matrix of probabilities for each possible edge in the causal graph and accepts the following arguments:

- **x** (ndarray) – Real-valued data matrix of shape `[n, d]`
- **interv** (ndarray, optional) – Binary matrix of the same shape as `x` with `interv[i,j] = 1` iff node `j` was intervened upon in observation `i`. (Default: `None`)
- **return_probs** (bool, optional) – Whether to return probability estimates for each edge. `False` simply clips the predictions to 0 and 1 using a decision threshold of 0.5. (Default: `True`, as the computational cost is the same.)
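As a small self-contained sketch of the `interv` argument described above, the mask can be built with plain numpy. The experiment layout here (three observational samples, three samples intervening on node 2) is purely hypothetical.

```python
import numpy as np

n, d = 6, 4  # n observations of d variables, matching a data matrix x of shape [n, d]

# interv[i, j] = 1 iff node j was intervened upon in observation i
interv = np.zeros((n, d), dtype=int)

# hypothetical layout: observations 0-2 are observational,
# observations 3-5 each intervene on node 2
interv[3:6, 2] = 1

# the mask is then passed alongside the data matrix, e.g.:
# g_prob = model(x=x, interv=interv)
```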
When sampling synthetic data via `avici.simulate_data`, the following domain specifiers (dataset distributions) are currently provided: `lin-gauss`, `lin-gauss-heterosked`, `lin-laplace-cauchy`, `rff-gauss`, `rff-gauss-heterosked`, `rff-laplace-cauchy`, and `gene-ecoli`, but custom config files can be specified, too. All these domains are defined inside `avici.config.examples`.
In the `example-custom` folder, we provide an extended README and a corresponding implementation that illustrate in detail how to train an AVICI model on a custom data-generating process.
In short, the following three components are needed for training a full model:

- `func.py`: (Optional) Python file defining custom data-generating processes. If you would like to train on data-generating processes not already provided by `avici.synthetic`, this file implements subclasses of `GraphModel` and `MechanismModel` doing so.
- `domain.yaml`: YAML file defining the training data distribution. This configuration file specifies the full distribution over datasets used for training. Several graph models and data-generating mechanisms are available out-of-the-box, so providing additional modules via `func.py` is optional. This file can also be used to simulate data in `avici.simulate_data`.
- `train.py`: Python training script. Fully-fledged training script for multi-device training (if available) based on the above configurations.
The checkpoints created by the training script can be loaded directly with the `avici.load_pretrained` function from above:

```python
import avici
model = avici.load_pretrained(checkpoint_dir="path/to/checkpoint", expects_counts=False)
```
When using `avici` for your research and applications, we recommend using the easy-to-use `main` branch and installing the latest stable release from PyPI via `pip` as explained above.
For custom installations, we recommend using `conda` and creating a new environment via:

```bash
conda env create --file environment.yaml
```

You then need to install the `avici` package with:

```bash
pip install -e .
```

On Apple M1 MacBooks, install the package by first setting up a conda environment using our `environment.yaml` config, then installing the requirements via

```bash
pip install -r requirements.txt
```

before finally running

```bash
pip install -e .
```

Directly installing `avici` via PyPI may install incompatible versions or builds of the package requirements, which can cause unexpected, low-level errors.
In addition to `main`, this repository also contains a `full` branch, which contains comprehensive code for reproducing the experimental results in Lorch et al. (2022). The purpose of `full` is reproducibility; the branch is no longer updated and may contain outdated notation and documentation.
```bibtex
@article{lorch2022amortized,
  title={Amortized Inference for Causal Structure Learning},
  author={Lorch, Lars and Sussex, Scott and Rothfuss, Jonas and Krause, Andreas and Sch{\"o}lkopf, Bernhard},
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  year={2022}
}
```