francois-rozet/inox

Stainless neural networks in JAX

Inox is a minimal JAX library for neural networks with an intuitive PyTorch-like syntax. As with Equinox, modules are represented as PyTrees, which enables complex architectures, easy manipulations, and functional transformations.
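The sketch below shows what it means for a module to be a PyTree. It uses plain JAX, not Inox: a hypothetical `Linear` class is registered with `jax.tree_util.register_pytree_node_class`, so tree utilities and transformations treat its weight and bias as leaves. This is not how Inox implements `nn.Module`. It only illustrates the underlying idea.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """Hypothetical module-as-PyTree sketch (not Inox's nn.Module)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # Children are the array leaves; the aux data (None here) would hold static info.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


layer = Linear(jnp.ones((2, 3)), jnp.zeros(2))

# tree_map rebuilds a new Linear whose leaves have all been doubled.
doubled = jax.tree_util.tree_map(lambda p: 2 * p, layer)

# Differentiating with respect to the layer returns gradients as another Linear.
grads = jax.grad(lambda m: m(jnp.ones(3)).sum())(layer)
```

Because the layer is a PyTree, `jax.grad` returns a gradient object with the same structure as the layer itself. This is also what lets the `loss_fn` examples further down take a whole model as their first argument.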

Inox aims to be a leaner version of Equinox by retaining only its core features: PyTrees and lifted transformations. In addition, Inox takes inspiration from other projects like NNX and Serket to provide a versatile interface. Despite the differences, Inox remains compatible with the Equinox ecosystem, and its components (modules, transformations, ...) are for the most part interchangeable with those of Equinox.

Inox means "stainless steel" in French 🔪

Installation

The inox package is available on PyPI and can be installed with pip.

pip install inox

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/francois-rozet/inox

Getting started

Modules are defined with an intuitive PyTorch-like syntax,

import jax
import inox
import inox.nn as nn

init_key, data_key = jax.random.split(jax.random.key(0))

class MLP(nn.Module):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(3, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, 3, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        return x

model = MLP(init_key)

and are compatible with JAX transformations.

X = jax.random.normal(data_key, (1024, 3))
Y = jax.numpy.sort(X, axis=-1)

@jax.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(model, X, Y)
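Here is the same pattern in plain JAX. It is a sketch in which a hypothetical dict of parameters stands in for an Inox module. `jax.vmap` maps a per-example function over the leading (batch) axis of the inputs. `jax.grad` returns gradients with the same tree structure as its first argument.

```python
import jax
import jax.numpy as jnp


def linear(params, x):
    # Per-example forward pass: x has shape (3,), with no batch axis.
    return params["w"] @ x + params["b"]


def loss_fn(params, x, y):
    # in_axes=(None, 0): params are shared, x is mapped over its first axis.
    pred = jax.vmap(linear, in_axes=(None, 0))(params, x)
    return jnp.mean((y - pred) ** 2)


init_key, data_key = jax.random.split(jax.random.key(0))
params = {"w": jax.random.normal(init_key, (3, 3)), "b": jnp.zeros(3)}

X = jax.random.normal(data_key, (1024, 3))
Y = jnp.sort(X, axis=-1)

# grads is a dict with the same keys and shapes as params.
grads = jax.jit(jax.grad(loss_fn))(params, X, Y)
```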

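However, trees can also contain non-array leaves such as strings or boolean flags, and these do not fit JAX transformations. Strings are not valid JAX types and are rejected. Flags are traced into arrays, which breaks Python control flow that depends on them. For this reason, Inox provides lifted transformations that treat all non-array leaves as static.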

model.name = 'stainless'  # not an array

@inox.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = inox.grad(loss_fn)(model, X, Y)
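The following sketch shows what "lifted" can mean. It is a hypothetical `lifted_jit` decorator in plain JAX, not Inox's implementation. The decorator flattens the arguments and passes array leaves to `jax.jit` as traced inputs. It folds every other leaf, together with the tree structure, into a single hashable static argument.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _is_array(x):
    return isinstance(x, (jax.Array, np.ndarray))


def _call(fn, static, arrays):
    # Rebuild the original (args, kwargs) from the static part and the traced arrays.
    treedef, mask, others = static
    array_iter, other_iter = iter(arrays), iter(others)
    leaves = [next(array_iter) if m else next(other_iter) for m in mask]
    args, kwargs = jax.tree_util.tree_unflatten(treedef, leaves)
    return fn(*args, **kwargs)


# fn and the static tuple are hashable, so jax.jit caches one compilation per value.
_jitted_call = jax.jit(_call, static_argnums=(0, 1))


def lifted_jit(fn):
    """Hypothetical sketch: jit `fn`, treating every non-array leaf as static."""

    def wrapper(*args, **kwargs):
        leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
        mask = tuple(_is_array(x) for x in leaves)
        arrays = [x for x, m in zip(leaves, mask) if m]
        others = tuple(x for x, m in zip(leaves, mask) if not m)
        return _jitted_call(fn, (treedef, mask, others), arrays)

    return wrapper


params = {"w": jnp.ones(3), "name": "stainless"}


@lifted_jit
def total(params, scale):
    # params["name"] is still a plain Python str here, so ordinary `if` works.
    return scale * jnp.sum(params["w"]) if params["name"] == "stainless" else 0.0


out = total(params, jnp.asarray(2.0))
```

In this sketch, each distinct value of a static leaf triggers a new compilation. Strings and flags therefore work best when they change rarely.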

Inox also provides a partition mechanism to split the static definition of a module (structure, strings, flags, ...) from its dynamic content (parameters, indices, statistics, ...), which is convenient for updating parameters.

model.mask = jax.numpy.array([1, 0, 1])  # not a parameter

static, params, others = model.partition(nn.Parameter)

@jax.jit
def loss_fn(params, others, x, y):
    model = static(params, others)
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(params, others, X, Y)
params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

model = static(params, others)
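The next sketch illustrates the idea behind partitioning in plain JAX. The hypothetical `partition` helper splits a PyTree into a closure `static`, which rebuilds the tree, and two flat dicts of leaves keyed by path. As a stand-in for `nn.Parameter`, it treats floating-point arrays as parameters and everything else as other state, such as an integer mask. Inox's actual `Module.partition` may differ in its return types.

```python
import jax
import jax.numpy as jnp


def is_param(x):
    # Assumption for this sketch: floating-point arrays are parameters.
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.floating)


def partition(tree, predicate):
    """Hypothetical sketch (not Inox's implementation)."""
    path_leaves, treedef = jax.tree_util.tree_flatten_with_path(tree)
    keys = [jax.tree_util.keystr(path) for path, _ in path_leaves]
    selected, others = {}, {}
    for key, (_, leaf) in zip(keys, path_leaves):
        (selected if predicate(leaf) else others)[key] = leaf

    def static(selected, others):
        # Look up every leaf by its path, in flattening order, then rebuild the tree.
        merged = {**others, **selected}
        return jax.tree_util.tree_unflatten(treedef, [merged[k] for k in keys])

    return static, selected, others


init_key, data_key = jax.random.split(jax.random.key(0))
model = {
    "l1": {"w": jax.random.normal(init_key, (3, 3)), "b": jnp.zeros(3)},
    "mask": jnp.array([1, 0, 1]),  # integer array: not a parameter
}
static, params, others = partition(model, is_param)


@jax.jit
def loss_fn(params, others, x, y):
    m = static(params, others)
    pred = (x @ m["l1"]["w"] + m["l1"]["b"]) * m["mask"]
    return jnp.mean((y - pred) ** 2)


x = jax.random.normal(data_key, (8, 3))
y = jnp.sort(x, axis=-1)

# Only the parameter dict is differentiated and updated; the mask is left untouched.
grads = jax.grad(loss_fn)(params, others, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
model = static(params, others)
```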

For more information, check out the documentation and tutorials at inox.readthedocs.io.

Contributing

If you have a question or an issue, or would like to contribute, please read our contributing guidelines.