diff --git a/README.md b/README.md
index 897818e..cd0f762 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,9 @@
-# Quax
+<h1 align='center'>Quax</h1>
+<h2 align='center'>JAX + multiple dispatch + custom array-ish objects</h2>
 
-JAX + multiple dispatch + custom array-ish objects, e.g.:
+For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.
+
+Applications include:
 
 - LoRA weight matrices
 - symbolic zeros
@@ -11,8 +14,6 @@ JAX + multiple dispatch + custom array-ish objects, e.g.:
 - arrays with physical units attached
 - etc! (See the built-in `quax.examples` library for most of the above!)
 
-For example, this can be mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.
-
 This works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects. This means it will work even with existing programs, that were not written to accept such array-ish objects!
 
 _(Just like how `jax.vmap` takes a program, but reinterprets each operation as its batched version, so to will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)_
@@ -31,7 +32,40 @@ Available at https://docs.kidger.site/quax.
 
 This example demonstrates everything you need to use the built-in `quax.examples.lora` library.
 
----8<-- ".lora-example.md"
+```python
+import equinox as eqx
+import jax.random as jr
+import quax
+import quax.examples.lora as lora
+
+#
+# Start off with any JAX program: here, the forward pass through a linear layer.
+#
+
+key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
+linear = eqx.nn.Linear(10, 12, key=key1)
+vector = jr.normal(key2, (10,))
+
+def run(model, x):
+    return model(x)
+
+run(linear, vector) # can call this as normal
+
+#
+# Now let's Lora-ify it.
+#
+
+# Step 1: make the weight be a LoraArray.
+lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
+lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
+# Step 2: quaxify and call the original function. The transform will call the
+# original function, whilst looking up any multiple dispatch rules registered.
+# (In this case for doing matmuls against LoraArrays.)
+quax.quaxify(run)(lora_linear, vector)
+# Appendix: Quax includes a helper to automatically apply Step 1 to all
+# `eqx.nn.Linear` layers in a model.
+lora_linear = lora.loraify(linear, rank=2, key=key3)
+```
 
 ## Work in progress!
 
diff --git a/docs/index.md b/docs/index.md
index c847abd..427395e 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,6 +1,12 @@
 # Quax
 
-JAX + multiple dispatch + custom array-ish objects, e.g.:
+JAX + multiple dispatch + custom array-ish objects.
+
+!!! Example
+
+    For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.
+
+Applications include:
 
 - LoRA weight matrices
 - symbolic zeros
@@ -11,10 +17,6 @@ JAX + multiple dispatch + custom array-ish objects, e.g.:
 - arrays with physical units attached
 - etc! (See the built-in `quax.examples` library for most of the above!)
 
-!!! Example
-
-    For example, this can be mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.
-
 This works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects. This means it will work even with existing programs, that were not written to accept such array-ish objects!
 
 _(Just like how `jax.vmap` takes a program, but reinterprets each operation as its batched version, so to will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)_
@@ -35,40 +37,7 @@ To start writing your own library (with your own array-ish type) using Quax, the
 
 This example demonstrates everything you need to use the built-in `quax.examples.lora` library.
 
-```python
-import equinox as eqx
-import jax.random as jr
-import quax
-import quax.examples.lora as lora
-
-#
-# Start off with any JAX program: here, the forward pass through a linear layer.
-#
-
-key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
-linear = eqx.nn.Linear(10, 12, key=key1)
-vector = jr.normal(key2, (10,))
-
-def run(model, x):
-    return model(x)
-
-run(linear, vector) # can call this as normal
-
-#
-# Now let's Lora-ify it.
-#
-
-# Step 1: make the weight be a LoraArray.
-lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
-lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
-# Step 2: quaxify and call the original function. The transform will call the
-# original function, whilst looking up any multiple dispatch rules registered.
-# (In this case for doing matmuls against LoraArrays.)
-quax.quaxify(run)(lora_linear, vector)
-# Appendix: Quax includes a helper to automatically apply Step 1 to all
-# `eqx.nn.Linear` layers in a model.
-lora_linear = lora.loraify(linear, rank=2, key=key3)
-```
+--8<-- ".lora-example.md"
 
 ## See also: other libraries in the JAX ecosystem
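Both files above describe rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`. That identity can be sanity-checked with a minimal NumPy sketch; the shapes and variable names here are illustrative assumptions, not taken from the Quax API:

```python
import numpy as np

# Illustrative shapes (hypothetical, not from Quax): a weight W plus a
# rank-2 LoRA update A @ B, applied to a vector v.
out_size, in_size, rank = 12, 10, 2
rng = np.random.default_rng(0)
W = rng.normal(size=(out_size, in_size))
A = rng.normal(size=(out_size, rank))
B = rng.normal(size=(rank, in_size))
v = rng.normal(size=(in_size,))

# Naive form: materialises the dense (out_size, in_size) matrix W + A @ B.
naive = (W + A @ B) @ v
# Rewritten form: never forms the dense update, just two skinny matmuls.
rewritten = W @ v + A @ (B @ v)

assert np.allclose(naive, rewritten)
```

The two forms agree to floating-point tolerance; the rewritten one avoids materialising the dense `W + AB` matrix, which is the saving the rewrite described above exploits.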