Welcome to our public repository, implementing the learning algorithm from our NeurIPS 2019 paper, Learning Deterministic Weighted Automata with Queries and Counterexamples.
Run the code in `full_example.py` for a full example, which will:

- Train an RNN on the given samples (use the argument `--spice-example` to run on problem set 0 from the SPiCe competition, or `--uhl-number` to select a UHL; options: 1, 2, 3).
- Extract from that RNN a PDFA (using weighted L*), WFAs (using the spectral algorithm), and n-grams.
- Evaluate the WER and NDCG of each extracted model against the RNN.
- Save the RNN and the extracted models, print the training and extraction times and the measure results, and draw the PDFAs, all in a new folder `results/[lang]_[timestamp]`.
Example runs:

```
python3 full_example.py --spice-example
python3 full_example.py --uhl-number=2
```
You can also set the parameters for all the extractions and measures, e.g.:

```
python3 full_example.py --spice-example --RNNClass=GRU --nPS=100 --lstar-time-limit=50 --ngram-total-sample-length=10000 --ndcg-k=3
```

More parameters are listed in `full_example.py`.
Everything here is implemented in Python 3. To use these notebooks, you will also need to install:
- PyTorch
- Graphviz (for drawing the extracted PDFAs)
- NumPy and SciPy
- Matplotlib (for network printouts during training)
If you are on a Mac using Homebrew, then NumPy, SciPy, Scikit-Learn, Matplotlib, Graphviz, and Jupyter should all hopefully work with `brew install numpy`, `brew install scipy`, etc. If you don't have Homebrew, or wherever `brew install` doesn't work, try `pip install` instead. For Graphviz you may first need to download and install the package yourself (Graphviz), after which you can run `pip install graphviz`. If you're lucky, `brew install graphviz` might take care of all of this for you by itself.
You can apply the full example directly to any language model (e.g. RNN, Transformer, or other) that provides the following API (a minimal sketch follows the list):

- `input_alphabet`, `end_token`, `internal_alphabet`: attributes listing the possible input tokens and the end token.
- `initial_state`: function returning the model's initial state.
- `next_state`: function with two parameters, the model's current state `s1` and the current input token `t`, that returns a new state `s2` without modifying `s1`.
- `state_probs_dist`: function with a single parameter, a model state, that returns the state's next-token distribution in the order of its `internal_alphabet` attribute (e.g., so that the probability of stopping after state `s` is `model.state_probs_dist(s)[model.internal_alphabet.index(model.end_token)]`).
- `state_char_prob`: function with two parameters, a model state `s` and an internal token `t`, equivalent to evaluating `model.state_probs_dist(s)[model.internal_alphabet.index(t)]`. (This exists separately because, for some language models, it might have a faster implementation than computing the entire distribution.)
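For concreteness, here is a minimal sketch of a model satisfying this API: it assigns a uniform next-token distribution in every state. The class name and internals are illustrative only and not part of this repository:

```python
class UniformModel:
    def __init__(self, input_alphabet, end_token="<EOS>"):
        self.input_alphabet = list(input_alphabet)
        self.end_token = end_token
        self.internal_alphabet = self.input_alphabet + [end_token]

    def initial_state(self):
        # here, a state is simply the (immutable) sequence read so far
        return ()

    def next_state(self, s1, t):
        # returns a fresh state; s1 is never modified
        return s1 + (t,)

    def state_probs_dist(self, s):
        # next-token distribution, ordered to match internal_alphabet
        p = 1.0 / len(self.internal_alphabet)
        return [p] * len(self.internal_alphabet)

    def state_char_prob(self, s, t):
        # same value as state_probs_dist(s)[internal_alphabet.index(t)],
        # but computed directly
        return 1.0 / len(self.internal_alphabet)
```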
For efficiency, it is also possible to directly implement the functions `weight` and `weights_for_sequences_with_same_prefix` as described in `LanguageModel.py`; this may be faster and more accurate when done directly in the model (as opposed to through the LanguageModel wrapper, which applies the `next_state` and `state_char_prob` functions several times to compute them). In particular, for the spectral extraction, implementing the function `weights_for_sequences_with_same_prefix` directly in the RNN using batching can speed up the reconstruction.
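To illustrate what the wrapper-level computation looks like (and why a direct, batched implementation can help), here is a hedged stand-alone sketch; the actual signatures are defined in `LanguageModel.py`, and the helpers below are assumptions written against the API above:

```python
def weight(model, seq):
    # weight of a full sequence: product of next-token probabilities along
    # the path, times the probability of the end token at the final state
    s, p = model.initial_state(), 1.0
    for t in seq:
        p *= model.state_char_prob(s, t)
        s = model.next_state(s, t)
    return p * model.state_char_prob(s, model.end_token)

def weights_for_sequences_with_same_prefix(model, prefix, suffixes):
    # walk the shared prefix once, then branch per suffix; a direct batched
    # implementation inside the RNN could process all branches in parallel
    s, p = model.initial_state(), 1.0
    for t in prefix:
        p *= model.state_char_prob(s, t)
        s = model.next_state(s, t)
    weights = []
    for suffix in suffixes:
        q, w = s, p
        for t in suffix:
            w *= model.state_char_prob(q, t)
            q = model.next_state(q, t)
        weights.append(w * model.state_char_prob(q, model.end_token))
    return weights
```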
You can cite this work using:

```
@incollection{NIPS2019_9062,
  title = {Learning Deterministic Weighted Automata with Queries and Counterexamples},
  author = {Weiss, Gail and Goldberg, Yoav and Yahav, Eran},
  booktitle = {Advances in Neural Information Processing Systems 32},
  editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
  pages = {8558--8569},
  year = {2019},
  publisher = {Curran Associates, Inc.},
}
```
This repository contains a sample train file from the SPiCe (Sequence Prediction Challenge) competition; more samples can be obtained and played with on the competition website. We also use Rémi Eyraud's NDCG evaluation code from the same challenge to compute NDCG on our extracted models, though we have changed it slightly to allow the use of different values of `k`.
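For reference, NDCG@k here scores a model's top-k ranked next tokens by the target model's probabilities, with a logarithmic rank discount. The following is an illustrative stand-alone sketch, not Rémi Eyraud's actual SPiCe code:

```python
import math

def ndcg_at_k(predicted_ranking, target_probs, k):
    # predicted_ranking: tokens ordered best-first by the evaluated model
    # target_probs: dict mapping each token to its probability under the
    # reference model (here, the RNN)
    gains = [target_probs.get(t, 0.0) for t in predicted_ranking[:k]]
    dcg = sum(g / math.log2(i + 2) for i, g in enumerate(gains))
    ideal = sorted(target_probs.values(), reverse=True)[:k]
    idcg = sum(g / math.log2(i + 2) for i, g in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0  # degenerate all-zero target
```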
The ALERGIA comparison in the paper was done using FLEXFRINGE.