diff --git a/README.md b/README.md
index 897818e..cd0f762 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,9 @@
# Quax
-JAX + multiple dispatch + custom array-ish objects, e.g.:
+JAX + multiple dispatch + custom array-ish objects.
+
+For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.
+
+Applications include:
- LoRA weight matrices
- symbolic zeros
@@ -11,8 +14,6 @@ JAX + multiple dispatch + custom array-ish objects, e.g.:
- arrays with physical units attached
- etc! (See the built-in `quax.examples` library for most of the above!)
-For example, this can be mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.
-
This works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects. This means it will work even with existing programs that were not written to accept such array-ish objects!
_(Just like how `jax.vmap` takes a program, but reinterprets each operation as its batched version, so too will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)_
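+
+To give a flavour of what a library author writes (the built-in `quax.examples` library contains complete versions), here is a minimal sketch: a custom array-ish type subclasses `quax.ArrayValue`, and individual operations are overloaded by registering dispatch rules against JAX primitives. The `Scaled` type below is purely illustrative, assuming only the public `quax.ArrayValue` / `quax.register` API.
+
+```python
+import equinox as eqx
+import jax
+import jax.core
+import quax
+
+class Scaled(quax.ArrayValue):
+    """Hypothetical array-ish type: an array carrying a static scale factor."""
+
+    array: jax.Array
+    scale: float = eqx.field(static=True)
+
+    def aval(self) -> jax.core.ShapedArray:
+        return jax.core.ShapedArray(self.array.shape, self.array.dtype)
+
+    def materialise(self) -> jax.Array:
+        # Fallback: collapse back to a plain JAX array.
+        return self.array * self.scale
+
+@quax.register(jax.lax.mul_p)
+def _(x: Scaled, y: Scaled) -> Scaled:
+    # Multiply the underlying arrays once; the scale factors combine separately.
+    return Scaled(x.array * y.array, x.scale * y.scale)
+```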
@@ -31,7 +32,40 @@ Available at https://docs.kidger.site/quax.
This example demonstrates everything you need to use the built-in `quax.examples.lora` library.
---8<-- ".lora-example.md"
+```python
+import equinox as eqx
+import jax.random as jr
+import quax
+import quax.examples.lora as lora
+
+#
+# Start off with any JAX program: here, the forward pass through a linear layer.
+#
+
+key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
+linear = eqx.nn.Linear(10, 12, key=key1)
+vector = jr.normal(key2, (10,))
+
+def run(model, x):
+ return model(x)
+
+run(linear, vector) # can call this as normal
+
+#
+# Now let's Lora-ify it.
+#
+
+# Step 1: make the weight be a LoraArray.
+lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
+lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
+# Step 2: quaxify and call the original function. The transform will call the
+# original function, whilst looking up any multiple dispatch rules registered.
+# (In this case for doing matmuls against LoraArrays.)
+quax.quaxify(run)(lora_linear, vector)
+# Appendix: Quax includes a helper to automatically apply Step 1 to all
+# `eqx.nn.Linear` layers in a model.
+lora_linear = lora.loraify(linear, rank=2, key=key3)
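+# The loraify'd model is then used exactly as in Step 2:
+quax.quaxify(run)(lora_linear, vector)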
+```
## Work in progress!
diff --git a/docs/index.md b/docs/index.md
index c847abd..427395e 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,6 +1,12 @@
# Quax
-JAX + multiple dispatch + custom array-ish objects, e.g.:
+JAX + multiple dispatch + custom array-ish objects.
+
+!!! Example
+
+    For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.
+
+Applications include:
- LoRA weight matrices
- symbolic zeros
@@ -11,10 +17,6 @@ JAX + multiple dispatch + custom array-ish objects, e.g.:
- arrays with physical units attached
- etc! (See the built-in `quax.examples` library for most of the above!)
-!!! Example
-
- For example, this can be mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.
-
This works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects. This means it will work even with existing programs that were not written to accept such array-ish objects!
_(Just like how `jax.vmap` takes a program, but reinterprets each operation as its batched version, so too will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)_
@@ -35,40 +37,7 @@ To start writing your own library (with your own array-ish type) using Quax, the
This example demonstrates everything you need to use the built-in `quax.examples.lora` library.
-```python
-import equinox as eqx
-import jax.random as jr
-import quax
-import quax.examples.lora as lora
-
-#
-# Start off with any JAX program: here, the forward pass through a linear layer.
-#
-
-key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
-linear = eqx.nn.Linear(10, 12, key=key1)
-vector = jr.normal(key2, (10,))
-
-def run(model, x):
- return model(x)
-
-run(linear, vector) # can call this as normal
-
-#
-# Now let's Lora-ify it.
-#
-
-# Step 1: make the weight be a LoraArray.
-lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
-lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
-# Step 2: quaxify and call the original function. The transform will call the
-# original function, whilst looking up any multiple dispatch rules registered.
-# (In this case for doing matmuls against LoraArrays.)
-quax.quaxify(run)(lora_linear, vector)
-# Appendix: Quax includes a helper to automatically apply Step 1 to all
-# `eqx.nn.Linear` layers in a model.
-lora_linear = lora.loraify(linear, rank=2, key=key3)
-```
+--8<-- ".lora-example.md"
## See also: other libraries in the JAX ecosystem