
Commit

Added filter_spec, disabled dynamic tracing, added docs.
patrick-kidger committed Feb 2, 2024
1 parent 1854c76 commit 8bbef0a
Showing 19 changed files with 1,607 additions and 212 deletions.
51 changes: 19 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Quax

Uses JAX's nonstandard interpretation to perform multiple dispatch on custom array-ish objects, like:
JAX + multiple dispatch + custom array-ish objects, e.g.:

- LoRA weight matrices
- symbolic zeros
@@ -11,69 +11,56 @@ Uses JAX's nonstandard interpretation to perform multiple dispatch on custom arr
- arrays with physical units attached
- etc! (See the built-in `quax.examples` library for most of the above!)

This works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects.
For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA matmul `(W + AB)v` into the more efficient `Wv + ABv`.

_(Just like how `jax.vmap` takes a program but reinterprets each operation as its batched version, so too will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)_
This works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects. This means it will work even with existing programs that were not written to accept such array-ish objects!

This means that it works even with existing programs that were not written to accept such array-ish objects: just wrap the program in the `quax.quaxify` transform.
_(Just like how `jax.vmap` takes a program but reinterprets each operation as its batched version, so too will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)_
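As a back-of-the-envelope sketch of why the `(W + AB)v` to `Wv + ABv` rewrite above helps, here is a rough FLOP count for each form. The counting convention (two FLOPs per multiply-add) and the shapes are illustrative assumptions, not anything Quax-specific; `W` is `(m, n)`, `A` is `(m, r)`, `B` is `(r, n)`, and `v` is `(n,)`, with rank `r` much smaller than `m` and `n`:

```python
def flops_dense(m, n, r):
    """Materialise A @ B (2*m*r*n), add it to W (m*n), then one matvec (2*m*n)."""
    return 2 * m * r * n + m * n + 2 * m * n

def flops_factored(m, n, r):
    """W @ v (2*m*n), B @ v (2*r*n), A @ (B @ v) (2*m*r), plus the final add (m)."""
    return 2 * m * n + 2 * r * n + 2 * m * r + m

m, n, r = 1000, 1000, 2
print(flops_dense(m, n, r))     # 7_000_000
print(flops_factored(m, n, r))  # 2_009_000
```

For a rank-2 adaptation of a 1000x1000 weight, the factored form is roughly 3.5x cheaper; the gap grows with the matrix size.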

## Installation

```
pip install quax
```

## Documentation

Available at https://docs.kidger.site/quax.

## Example: LoRA

```python
import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

# Start off with any JAX program: here, the forward pass through a linear layer.
key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))

# Make some of the inputs be an array-ish object. This function finds all
# `eqx.nn.Linear` layers, and wraps their weights in `LoraArray`s.
lora_linear = lora.loraify(linear, rank=2, key=key3)
# For this simple model, we could also do it manually.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)

# Wrap your function call in quaxify. This transform calls your original function,
# whilst looking up any multiple dispatch rules registered for any custom array-ish
# objects.
out = quax.quaxify(lora_linear)(vector)
```
This example demonstrates everything you need to use the built-in `quax.examples.lora` library.

--8<-- ".lora-example.md"

## Work in progress!

This library is a work in progress! Right now it should support enough to run LoRA on common models. However, some operations (e.g. `jax.lax.cond`) are not yet supported. If you attempt to use these then an error will be thrown whilst tracing your program.
Right now, the following are not supported:

- Control flow primitives (e.g. `jax.lax.cond`).
- `jax.custom_vjp`

If you find yourself hitting any of these, then go ahead and open an issue, and/or a pull request!
It should be fairly straightforward to add support for these; open an issue or pull request.

## See also: other libraries in the JAX ecosystem

[Equinox](https://github.com/patrick-kidger/equinox): neural networks.

[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.

[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.

[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.

[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.

[Lineax](https://github.com/google/lineax): linear solvers.
[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.

[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.

[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).

[sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.

[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.

1 change: 1 addition & 0 deletions docs/.htaccess
@@ -0,0 +1 @@
ErrorDocument 404 /equinox/404.html
34 changes: 34 additions & 0 deletions docs/.lora-example.md
@@ -0,0 +1,34 @@
```python
import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

#
# Start off with any JAX program: here, the forward pass through a linear layer.
#

key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))

def run(model, x):
return model(x)

run(linear, vector) # can call this as normal

#
# Now let's Lora-ify it.
#

# Step 1: make the weight be a LoraArray.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
# Step 2: quaxify and call the original function. The transform will call the
# original function, whilst looking up any multiple dispatch rules registered.
# (In this case for doing matmuls against LoraArrays.)
quax.quaxify(run)(lora_linear, vector)
# Appendix: Quax includes a helper to automatically apply Step 1 to all
# `eqx.nn.Linear` layers in a model.
lora_linear = lora.loraify(linear, rank=2, key=key3)
```
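To make the "looking up any multiple dispatch rules registered" step in the example above concrete, here is a toy illustration of the dispatch idea. This is an assumption-laden sketch, not Quax's actual machinery: the `register`, `DenseMatrix`, and `LowRank` names are hypothetical stand-ins, and rules here are keyed on plain argument types rather than on JAX primitives.

```python
# Toy dispatch table: maps a tuple of argument types to a rule.
_rules = {}

def register(*types):
    """Register the decorated function as the rule for this type combination."""
    def wrapper(fn):
        _rules[types] = fn
        return fn
    return wrapper

class DenseMatrix:
    def __init__(self, rows):
        self.rows = rows

class LowRank:
    """Hypothetical stand-in for an array-ish type like LoraArray: represents A @ B."""
    def __init__(self, a, b):
        self.a, self.b = a, b

def matvec(mat, vec):
    # The "operation" consults the dispatch table based on its argument types.
    rule = _rules[type(mat), type(vec)]
    return rule(mat, vec)

@register(DenseMatrix, list)
def _dense_matvec(mat, vec):
    return [sum(w * x for w, x in zip(row, vec)) for row in mat.rows]

@register(LowRank, list)
def _lowrank_matvec(mat, vec):
    bv = matvec(DenseMatrix(mat.b), vec)   # B @ v first ...
    return matvec(DenseMatrix(mat.a), bv)  # ... then A @ (B @ v): never form A @ B.

print(matvec(LowRank([[1], [1]], [[1, 1]]), [2, 3]))  # [5, 5]
```

The point of the design is that `matvec` itself never special-cases `LowRank`: the behaviour lives entirely in the registered rule, so new array-ish types can be added without touching existing call sites.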
20 changes: 20 additions & 0 deletions docs/_overrides/partials/source.html
@@ -0,0 +1,20 @@
{% import "partials/language.html" as lang with context %}
<a href="{{ config.repo_url }}" title="{{ lang.t('source.link.title') }}" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
{% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %}
{% include ".icons/" ~ icon ~ ".svg" %}
</div>
<div class="md-source__repository">
{{ config.repo_name }}
</div>
</a>
{% if config.theme.twitter_url %}
<a href="{{ config.theme.twitter_url }}" title="Go to Twitter" class="md-source">
<div class="md-source__icon md-icon">
{% include ".icons/fontawesome/brands/twitter.svg" %}
</div>
<div class="md-source__repository">
{{ config.theme.twitter_name }}
</div>
</a>
{% endif %}
2 changes: 2 additions & 0 deletions docs/_static/.README.md
@@ -0,0 +1,2 @@
The favicon is `math-integral` from https://materialdesignicons.com, found by way of https://pictogrammers.com.
(The logo is `math-integral-box`.)
167 changes: 167 additions & 0 deletions docs/_static/custom_css.css
@@ -0,0 +1,167 @@
/* Fix /page#foo going to the top of the viewport and being hidden by the navbar */
html {
scroll-padding-top: 50px;
}

/* Fit the Twitter handle alongside the GitHub one in the top right. */

div.md-header__source {
width: revert;
max-width: revert;
}

a.md-source {
display: inline-block;
}

.md-source__repository {
max-width: 100%;
}

/* Emphasise sections of nav on left hand side */

nav.md-nav {
padding-left: 5px;
}

nav.md-nav--secondary {
border-left: revert !important;
}

.md-nav__title {
font-size: 0.9rem;
}

.md-nav__item--section > .md-nav__link {
font-size: 0.9rem;
}

/* Indent autogenerated documentation */

div.doc-contents {
padding-left: 25px;
border-left: 4px solid rgb(230, 230, 230);
}

/* Increase visibility of splitters "---" */

[data-md-color-scheme="default"] .md-typeset hr {
border-bottom-color: rgb(0, 0, 0);
border-bottom-width: 1pt;
}

[data-md-color-scheme="slate"] .md-typeset hr {
border-bottom-color: rgb(230, 230, 230);
}

/* More space at the bottom of the page */

.md-main__inner {
margin-bottom: 1.5rem;
}

/* Remove prev/next footer buttons */

.md-footer__inner {
display: none;
}

/* Change font sizes */

html {
/* Decrease font size for overall webpage
Down from 137.5% which is the Material default */
font-size: 110%;
}

.md-typeset .admonition {
/* Increase font size in admonitions */
font-size: 100% !important;
}

.md-typeset details {
/* Increase font size in details */
font-size: 100% !important;
}

.md-typeset h1 {
font-size: 1.6rem;
}

.md-typeset h2 {
font-size: 1.5rem;
}

.md-typeset h3 {
font-size: 1.3rem;
}

.md-typeset h4 {
font-size: 1.1rem;
}

.md-typeset h5 {
font-size: 0.9rem;
}

.md-typeset h6 {
font-size: 0.8rem;
}

/* Bugfix: remove the superfluous parts generated when doing:
??? Blah
::: library.something
*/

.md-typeset details .mkdocstrings > h4 {
display: none;
}

.md-typeset details .mkdocstrings > h5 {
display: none;
}

/* Change default colours for <a> tags */

[data-md-color-scheme="default"] {
--md-typeset-a-color: rgb(0, 189, 164) !important;
}
[data-md-color-scheme="slate"] {
--md-typeset-a-color: rgb(0, 189, 164) !important;
}

/* Highlight functions, classes etc. type signatures. Really helps to make clear where
one item ends and another begins. */

[data-md-color-scheme="default"] {
--doc-heading-color: #DDD;
--doc-heading-border-color: #CCC;
--doc-heading-color-alt: #F0F0F0;
}
[data-md-color-scheme="slate"] {
--doc-heading-color: rgb(25,25,33);
--doc-heading-border-color: rgb(25,25,33);
--doc-heading-color-alt: rgb(33,33,44);
--md-code-bg-color: rgb(38,38,50);
}

h4.doc-heading {
/* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/
background-color: var(--doc-heading-color);
border: solid var(--doc-heading-border-color);
border-width: 1.5pt;
border-radius: 2pt;
padding: 0pt 5pt 2pt 5pt;
}
h5.doc-heading, h6.heading {
background-color: var(--doc-heading-color-alt);
border-radius: 2pt;
padding: 0pt 5pt 2pt 5pt;
}

/* Make errors in notebooks have scrolling */
.output_error > pre {
overflow: auto;
}
Binary file added docs/_static/favicon.png
16 changes: 16 additions & 0 deletions docs/_static/mathjax.js
@@ -0,0 +1,16 @@
window.MathJax = {
tex: {
inlineMath: [["\\(", "\\)"]],
displayMath: [["\\[", "\\]"]],
processEscapes: true,
processEnvironments: true
},
options: {
ignoreHtmlClass: ".*|",
processHtmlClass: "arithmatex"
}
};

document$.subscribe(() => {
MathJax.typesetPromise()
})
22 changes: 22 additions & 0 deletions docs/api/lora.md
@@ -0,0 +1,22 @@
# quax.examples.lora

As an (actually quite useful) tech demo, Quax provides an implementation of [LoRA: Low-Rank Adaptation](https://arxiv.org/abs/2106.09685), which is a popular fine-tuning method for large neural network models.

Most of the time you will just need the [`quax.examples.lora.loraify`][] function, which transforms an existing [Equinox](https://github.com/patrick-kidger/equinox) model.

For a user who wants to LoRA'ify only part of their model, the underlying [`quax.examples.lora.LoraArray`][] array-ish object (which subclasses [`quax.ArrayValue`][]) is also available.
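For intuition, a sketch of the parameter count under the usual LoRA conventions (the symbols \(m\), \(n\), \(r\) here are generic notation, not Quax API names): LoRA replaces a frozen weight \(W\) by

\[
W' = W + AB, \qquad A \in \mathbb{R}^{m \times r},\quad B \in \mathbb{R}^{r \times n},
\]

so only the \(r(m+n)\) entries of \(A\) and \(B\) are trained, rather than the \(mn\) entries of \(W\). For example, with \(m = n = 1000\) and rank \(r = 2\), that is \(4000\) trainable parameters in place of \(10^6\).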

---

::: quax.examples.lora.loraify

::: quax.examples.lora.LoraArray
selection:
members:
- __init__

## Example

Here's a copy of the LoRA example from the README again:

--8<-- ".lora-example.md"
26 changes: 26 additions & 0 deletions docs/api/quax.md
@@ -0,0 +1,26 @@
# quax

An end user of a library built on Quax needs only one thing from this section: the [`quax.quaxify`][] function.

::: quax.quaxify

---

A developer of a library built on Quax (e.g. if you wanted to write your own library analogous to `quax.examples.lora`) should additionally know about the following functionality.

!!! Info

See also the [tutorials](../examples/custom_rules.ipynb) for creating your own array-ish Quax types.

::: quax.register

::: quax.Value
selection:
members:
- aval
- default
- materialise

::: quax.ArrayValue
selection:
members: false
