Skip to content

Commit

Permalink
Fix landing page
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Feb 2, 2024
1 parent 809bcb0 commit 0714f74
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 44 deletions.
44 changes: 39 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Quax
<h1 align="center">Quax</h1>
<h2 align="center">JAX + multiple dispatch + custom array-ish objects</h2>

JAX + multiple dispatch + custom array-ish objects, e.g.:
For example, this can be mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.

Applications include:

- LoRA weight matrices
- symbolic zeros
Expand All @@ -11,8 +14,6 @@ JAX + multiple dispatch + custom array-ish objects, e.g.:
- arrays with physical units attached
- etc! (See the built-in `quax.examples` library for most of the above!)

For example, this can be mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.

This works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects. This means it will work even with existing programs, that were not written to accept such array-ish objects!

_(Just like how `jax.vmap` takes a program, but reinterprets each operation as its batched version, so to will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)_
Expand All @@ -31,7 +32,40 @@ Available at https://docs.kidger.site/quax.

This example demonstrates everything you need to use the built-in `quax.examples.lora` library.

--8<-- ".lora-example.md"
```python
import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

#
# Start off with any JAX program: here, the forward pass through a linear layer.
#

key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))

def run(model, x):
return model(x)

run(linear, vector) # can call this as normal

#
# Now let's Lora-ify it.
#

# Step 1: make the weight be a LoraArray.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
# Step 2: quaxify and call the original function. The transform will call the
# original function, whilst looking up any multiple dispatch rules registered.
# (In this case for doing matmuls against LoraArrays.)
quax.quaxify(run)(lora_linear, vector)
# Appendix: Quax includes a helper to automatically apply Step 1 to all
# `eqx.nn.Linear` layers in a model.
lora_linear = lora.loraify(linear, rank=2, key=key3)
```

## Work in progress!

Expand Down
47 changes: 8 additions & 39 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# Quax

JAX + multiple dispatch + custom array-ish objects, e.g.:
JAX + multiple dispatch + custom array-ish objects.

!!! Example

For example, this can be mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.

Applications include:

- LoRA weight matrices
- symbolic zeros
Expand All @@ -11,10 +17,6 @@ JAX + multiple dispatch + custom array-ish objects, e.g.:
- arrays with physical units attached
- etc! (See the built-in `quax.examples` library for most of the above!)

!!! Example

For example, this can be mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.

This works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects. This means it will work even with existing programs, that were not written to accept such array-ish objects!

_(Just like how `jax.vmap` takes a program, but reinterprets each operation as its batched version, so to will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)_
Expand All @@ -35,40 +37,7 @@ To start writing your own library (with your own array-ish type) using Quax, the

This example demonstrates everything you need to use the built-in `quax.examples.lora` library.

```python
import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

#
# Start off with any JAX program: here, the forward pass through a linear layer.
#

key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))

def run(model, x):
return model(x)

run(linear, vector) # can call this as normal

#
# Now let's Lora-ify it.
#

# Step 1: make the weight be a LoraArray.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
# Step 2: quaxify and call the original function. The transform will call the
# original function, whilst looking up any multiple dispatch rules registered.
# (In this case for doing matmuls against LoraArrays.)
quax.quaxify(run)(lora_linear, vector)
# Appendix: Quax includes a helper to automatically apply Step 1 to all
# `eqx.nn.Linear` layers in a model.
lora_linear = lora.loraify(linear, rank=2, key=key3)
```
--8<-- ".lora-example.md"

## See also: other libraries in the JAX ecosystem

Expand Down

0 comments on commit 0714f74

Please sign in to comment.