From 0c6304efbe6d367d89dcef4fb050466deac0544f Mon Sep 17 00:00:00 2001
From: Jonas
Date: Thu, 12 Dec 2024 17:25:07 +0100
Subject: [PATCH] first commit for a decorator that transforms JAX to pytensor

---
 pyproject.toml                   |   2 +-
 pytensor/link/jax/ops.py         | 424 +++++++++++++++++++++++++++++++
 tests/link/jax/test_as_jax_op.py |  26 ++
 3 files changed, 451 insertions(+), 1 deletion(-)
 create mode 100644 pytensor/link/jax/ops.py
 create mode 100644 tests/link/jax/test_as_jax_op.py

diff --git a/pyproject.toml b/pyproject.toml
index 4e2a1fdb05..bf43b19ccb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,7 +78,7 @@ tests = [
     "pytest-sphinx",
 ]
 rtd = ["sphinx>=5.1.0,<6", "pygments", "pydot", "pydot2", "pydot-ng"]
-jax = ["jax", "jaxlib"]
+jax = ["jax", "jaxlib", "equinox"]
 numba = ["numba>=0.57", "llvmlite"]
 
 [tool.setuptools.packages.find]
diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py
new file mode 100644
index 0000000000..130ece6eda
--- /dev/null
+++ b/pytensor/link/jax/ops.py
@@ -0,0 +1,424 @@
+"""Convert a JAX function to a pytensor-compatible function."""
+
+import functools as ft
+import logging
+from collections.abc import Sequence
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.tree_util import tree_flatten, tree_map, tree_unflatten
+
+import pytensor.compile.builders
+import pytensor.tensor as pt
+from pytensor.gradient import DisconnectedType
+from pytensor.graph import Apply, Op
+from pytensor.link.jax.dispatch import jax_funcify
+
+
+log = logging.getLogger(__name__)
+
+
+def _filter_ptvars(x):
+    return isinstance(x, pt.Variable)
+
+
+def as_jax_op(jaxfunc, name=None):
+    """Return a pytensor-compatible function from a JAX jittable function.
+
+    This decorator transforms any JAX-jittable function into a function that
+    accepts and returns `pytensor.Variable`s. The JAX function can accept any
+    nested Python structure (a pytree) as input and return any nested Python
+    structure as output.
+
+    The output types of the returned values must be representable as pytensor
+    types. A unique name should also be passed in case the name of the jaxfunc
+    is identical to that of some other node. The design of this function is
+    based on
+    https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/
+
+    Parameters
+    ----------
+    jaxfunc : JAX-jittable function
+        The function for which the Op is created; it can return multiple tensors
+        as a tuple. All return values must be transformable to pytensor.Variable.
+    name : str
+        Name of the created pytensor Op; defaults to the name of the passed
+        function. Only used internally in the pytensor graph.
+
+    Returns
+    -------
+    A function that can be used inside a pymc.Model, is differentiable, and whose
+    resulting model can be compiled either with the default C backend or the JAX
+    backend.
+
+    Notes
+    -----
+    The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt,
+    available at
+    `pymc-labs.io <https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/>`__.
+    To accept functions and non-pytensor variables as input, the function makes
+    use of :func:`equinox.partition` and :func:`equinox.combine` to split and
+    combine the variables. Shapes are inferred using
+    :func:`pytensor.compile.builders.infer_shape` and :func:`jax.eval_shape`.
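+
+    Examples
+    --------
+    A minimal usage sketch (the wrapped function and the shapes are illustrative
+    only):
+
+    >>> import jax.numpy as jnp
+    >>> import pytensor.tensor as pt
+    >>> x = pt.tensor("x", shape=(3,))
+    >>> y = pt.tensor("y", shape=(3,))
+    >>> @as_jax_op
+    ... def f(x, y):
+    ...     return jnp.exp(x) + y
+    >>> out = f(x, y)  # `out` is a pytensor.Variable and can be used in a graph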
+ """ + + def func(*args, **kwargs): + """Return a pytensor from a jax jittable function.""" + ### Split variables: in the ones that will be transformed to JAX inputs, + ### pytensor.Variables; _WrappedFunc, that are functions that have been returned + ### from a transformed function; and the rest, static variables that are not + ### transformed. + + pt_vars, static_vars_tmp = eqx.partition( + (args, kwargs), _filter_ptvars, is_leaf=callable + ) + # is_leaf=callable is used, as libraries like diffrax or equinox might return + # functions that are still seen as a nested pytree structure. We consider them + # as wrappable functions, that will be wrapped with _WrappedFunc. + + func_vars, static_vars = eqx.partition( + static_vars_tmp, lambda x: isinstance(x, _WrappedFunc), is_leaf=callable + ) + vars_from_func = tree_map(lambda x: x.get_vars(), func_vars) + pt_vars = dict(vars=pt_vars, vars_from_func=vars_from_func) + """ + def func_unwrapped(vars_all, static_vars): + vars, vars_from_func = vars_all["vars"], vars_all["vars_from_func"] + func_vars_evaled = tree_map( + lambda x, y: x.get_func_with_vars(y), func_vars, vars_from_func + ) + args, kwargs = eqx.combine(vars, static_vars, func_vars_evaled) + return self.jaxfunc(*args, **kwargs) + """ + + pt_vars_flat, vars_treedef = tree_flatten(pt_vars) + pt_vars_types_flat = [var.type for var in pt_vars_flat] + shapes_vars_flat = pytensor.compile.builders.infer_shape(pt_vars_flat, (), ()) + shapes_vars = tree_unflatten(vars_treedef, shapes_vars_flat) + + dummy_inputs_jax = jax.tree_util.tree_map( + lambda var, shape: jnp.empty( + [int(dim.eval()) for dim in shape], dtype=var.type.dtype + ), + pt_vars, + shapes_vars, + ) + + # Combine the static variables with the inputs, and split them again in the + # output. Static variables don't take part in the graph, or might be a + # a function that is returned. + jaxfunc_partitioned, static_out_dic = _partition_jaxfunc( + jaxfunc, static_vars, func_vars + ) + + func_flattened = _flatten_func(jaxfunc_partitioned, vars_treedef) + + jaxtypes_outvars = jax.eval_shape( + ft.partial(jaxfunc_partitioned, vars=dummy_inputs_jax), + ) + + jaxtypes_outvars_flat, outvars_treedef = tree_flatten(jaxtypes_outvars) + + pttypes_outvars = [ + pt.TensorType(dtype=var.dtype, shape=var.shape) + for var in jaxtypes_outvars_flat + ] + + ### Call the function that accepts flat inputs, which in turn calls the one that + ### combines the inputs and static variables. + jitted_sol_op_jax = jax.jit(func_flattened) + len_gz = len(pttypes_outvars) + + vjp_sol_op_jax = _get_vjp_sol_op_jax(func_flattened, len_gz) + jitted_vjp_sol_op_jax = jax.jit(vjp_sol_op_jax) + + if name is None: + curr_name = jaxfunc.__name__ + else: + curr_name = name + + # Get classes that creates a Pytensor Op out of our function that accept + # flattened inputs. They are created each time, to set a custom name for the + # class. 
+        SolOp, VJPSolOp = _return_pytensor_ops_classes(curr_name)
+
+        local_op = SolOp(
+            vars_treedef,
+            outvars_treedef,
+            input_types=pt_vars_types_flat,
+            output_types=pttypes_outvars,
+            jitted_sol_op_jax=jitted_sol_op_jax,
+            jitted_vjp_sol_op_jax=jitted_vjp_sol_op_jax,
+        )
+
+        @jax_funcify.register(SolOp)
+        def sol_op_jax_funcify(op, **kwargs):
+            return local_op.perform_jax
+
+        @jax_funcify.register(VJPSolOp)
+        def vjp_sol_op_jax_funcify(op, **kwargs):
+            return local_op.vjp_sol_op.perform_jax
+
+        ### Evaluate the Pytensor Op and return unflattened results
+        output_flat = local_op(*pt_vars_flat)
+        if not isinstance(output_flat, Sequence):
+            output_flat = [output_flat]  # tree_unflatten expects a sequence.
+        outvars = tree_unflatten(outvars_treedef, output_flat)
+
+        static_outfuncs, static_outvars = eqx.partition(
+            static_out_dic["out"], callable, is_leaf=callable
+        )
+
+        static_outfuncs_flat, treedef_outfuncs = jax.tree_util.tree_flatten(
+            static_outfuncs, is_leaf=callable
+        )
+        for i_func, _ in enumerate(static_outfuncs_flat):
+            static_outfuncs_flat[i_func] = _WrappedFunc(
+                jaxfunc, i_func, *args, **kwargs
+            )
+
+        static_outfuncs = jax.tree_util.tree_unflatten(
+            treedef_outfuncs, static_outfuncs_flat
+        )
+        static_vars = eqx.combine(static_outfuncs, static_outvars, is_leaf=callable)
+
+        output = eqx.combine(outvars, static_vars, is_leaf=callable)
+
+        return output
+
+    return func
+
+
+class _WrappedFunc:
+    def __init__(self, exterior_func, i_func, *args, **kwargs):
+        self.args = args
+        self.kwargs = kwargs
+        self.i_func = i_func
+        vars, static_vars = eqx.partition(
+            (self.args, self.kwargs), _filter_ptvars, is_leaf=callable
+        )
+        self.vars = vars
+        self.static_vars = static_vars
+        self.exterior_func = exterior_func
+
+    def __call__(self, *args, **kwargs):
+        # If called, assume that args and kwargs are pytensor variables, so return
+        # the result as pytensor variables as well.
+        def f(func, *args, **kwargs):
+            res = func(*args, **kwargs)
+            return res
+
+        return as_jax_op(f)(self, *args, **kwargs)
+
+    def get_vars(self):
+        return self.vars
+
+    def get_func_with_vars(self, vars):
+        # Use variables other than the saved ones to generate the function. This is
+        # used to transform vars externally from pytensor to JAX, and then create
+        # the function which is returned.
+
+        args, kwargs = eqx.combine(vars, self.static_vars, is_leaf=callable)
+        output = self.exterior_func(*args, **kwargs)
+        outfuncs, _ = eqx.partition(output, callable, is_leaf=callable)
+        outfuncs_flat, _ = jax.tree_util.tree_flatten(outfuncs, is_leaf=callable)
+        interior_func = outfuncs_flat[self.i_func]
+        return interior_func
+
+
+def _get_vjp_sol_op_jax(jaxfunc, len_gz):
+    def vjp_sol_op_jax(args):
+        y0 = args[:-len_gz]
+        gz = args[-len_gz:]
+        if len(gz) == 1:
+            gz = gz[0]
+
+        def func(*inputs):
+            return jaxfunc(inputs)
+
+        primals, vjp_fn = jax.vjp(func, *y0)
+        gz = tree_map(
+            lambda g, primal: jnp.broadcast_to(g, jnp.shape(primal)),
+            gz,
+            primals,
+        )
+        if len(y0) == 1:
+            return vjp_fn(gz)[0]
+        else:
+            return tuple(vjp_fn(gz))
+
+    return vjp_sol_op_jax
+
+
+def _partition_jaxfunc(jaxfunc, static_vars, func_vars):
+    """Partition the JAX function into static and non-static variables.
+
+    Returns a function that accepts only non-static variables and returns the
+    non-static output variables.
+    The static variables of the output are stored in a dictionary and returned,
+    to allow referencing them after creating the function.
+
+    Additionally, wrapped functions saved in func_vars are regenerated with
+    vars["vars_from_func"] as input, to allow the transformation of those
+    variables.
+    """
+    static_out_dic = {"out": None}
+
+    def jaxfunc_partitioned(vars):
+        vars, vars_from_func = vars["vars"], vars["vars_from_func"]
+        func_vars_evaled = tree_map(
+            lambda x, y: x.get_func_with_vars(y), func_vars, vars_from_func
+        )
+        args, kwargs = eqx.combine(
+            vars, static_vars, func_vars_evaled, is_leaf=callable
+        )
+
+        out = jaxfunc(*args, **kwargs)
+        outvars, static_out = eqx.partition(out, eqx.is_array, is_leaf=callable)
+        static_out_dic["out"] = static_out
+        return outvars
+
+    return jaxfunc_partitioned, static_out_dic
+
+
+### Construct the function that accepts flat inputs and returns flat outputs.
+def _flatten_func(jaxfunc, vars_treedef):
+    def func_flattened(vars_flat):
+        vars = tree_unflatten(vars_treedef, vars_flat)
+        outvars = jaxfunc(vars)
+        outvars_flat, _ = tree_flatten(outvars)
+        return _normalize_flat_output(outvars_flat)
+
+    return func_flattened
+
+
+def _normalize_flat_output(output):
+    if len(output) > 1:
+        # Transform to a tuple, because JAX distinguishes between tuples and
+        # lists while pytensor does not.
+        return tuple(output)
+    else:
+        return output[0]
+
+
+def _return_pytensor_ops_classes(name):
+    class SolOp(Op):
+        def __init__(
+            self,
+            input_treedef,
+            output_treedef,
+            input_types,
+            output_types,
+            jitted_sol_op_jax,
+            jitted_vjp_sol_op_jax,
+        ):
+            self.vjp_sol_op = None
+            self.input_treedef = input_treedef
+            self.output_treedef = output_treedef
+            self.input_types = input_types
+            self.output_types = output_types
+            self.jitted_sol_op_jax = jitted_sol_op_jax
+            self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax
+
+        def make_node(self, *inputs):
+            self.num_inputs = len(inputs)
+
+            # Define our output variables
+            outputs = [
+                pt.as_tensor_variable(out_type()) for out_type in self.output_types
+            ]
+            self.num_outputs = len(outputs)
+
+            self.vjp_sol_op = VJPSolOp(
+                self.input_treedef,
+                self.input_types,
+                self.jitted_vjp_sol_op_jax,
+            )
+
+            return Apply(self, inputs, outputs)
+
+        def perform(self, node, inputs, outputs):
+            results = self.jitted_sol_op_jax(inputs)
+            if self.num_outputs > 1:
+                for i in range(self.num_outputs):
+                    outputs[i][0] = np.array(results[i], self.output_types[i].dtype)
+            else:
+                outputs[0][0] = np.array(results, self.output_types[0].dtype)
+
+        def perform_jax(self, *inputs):
+            results = self.jitted_sol_op_jax(inputs)
+            return results
+
+        def grad(self, inputs, output_gradients):
+            # If an output is not used, it is disconnected and doesn't have a
+            # gradient. Set the gradient to zero for those outputs here.
+            for i in range(self.num_outputs):
+                if isinstance(output_gradients[i].type, DisconnectedType):
+                    if None not in self.output_types[i].shape:
+                        output_gradients[i] = pt.zeros(
+                            self.output_types[i].shape, self.output_types[i].dtype
+                        )
+                    else:
+                        output_gradients[i] = pt.zeros((), self.output_types[i].dtype)
+            result = self.vjp_sol_op(inputs, output_gradients)
+
+            if self.num_inputs > 1:
+                return result
+            else:
+                return (result,)  # Pytensor requires a tuple here
+
+    # Vector-Jacobian product Op
+    class VJPSolOp(Op):
+        def __init__(
+            self,
+            input_treedef,
+            input_types,
+            jitted_vjp_sol_op_jax,
+        ):
+            self.input_treedef = input_treedef
+            self.input_types = input_types
+            self.jitted_vjp_sol_op_jax = jitted_vjp_sol_op_jax
+
+        def make_node(self, y0, gz):
+            y0 = [
+                pt.as_tensor_variable(_y).astype(self.input_types[i].dtype)
+                for i, _y in enumerate(y0)
+            ]
+            gz_not_disconnected = [
+                pt.as_tensor_variable(_gz)
+                for _gz in gz
+                if not isinstance(_gz.type, DisconnectedType)
+            ]
+            outputs = [in_type() for in_type in self.input_types]
+            self.num_outputs = len(outputs)
+            return Apply(self, y0 + gz_not_disconnected, outputs)
+
+        def perform(self, node, inputs, outputs):
+            results = self.jitted_vjp_sol_op_jax(tuple(inputs))
+            if len(self.input_types) > 1:
+                for i, result in enumerate(results):
+                    outputs[i][0] = np.array(result, self.input_types[i].dtype)
+            else:
+                outputs[0][0] = np.array(results, self.input_types[0].dtype)
+
+        def perform_jax(self, *inputs):
+            results = self.jitted_vjp_sol_op_jax(tuple(inputs))
+            if self.num_outputs == 1:
+                if isinstance(results, Sequence):
+                    return results[0]
+                else:
+                    return results
+            else:
+                return tuple(results)
+
+    SolOp.__name__ = name
+    SolOp.__qualname__ = ".".join(SolOp.__qualname__.split(".")[:-1] + [name])
+
+    VJPSolOp.__name__ = "VJP_" + name
+    VJPSolOp.__qualname__ = ".".join(
+        VJPSolOp.__qualname__.split(".")[:-1] + ["VJP_" + name]
+    )
+
+    return SolOp, VJPSolOp
diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py
new file mode 100644
index 0000000000..6feb1124a2
--- /dev/null
+++ b/tests/link/jax/test_as_jax_op.py
@@ -0,0 +1,26 @@
+import jax
+import numpy as np
+
+from pytensor import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.link.jax.ops import as_jax_op
+from pytensor.tensor import tensor
+from tests.link.jax.test_basic import compare_jax_and_py
+
+def test_as_jax_op1():
+    # Two input parameters, single output
+    rng = np.random.default_rng(14)
+    x = tensor("a", shape=(2,))
+    y = tensor("b", shape=(2,))
+    test_values = [
+        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
+    ]
+
+    @as_jax_op
+    def f(x, y):
+        return jax.nn.sigmoid(x + y)
+
+    out = f(x, y)
+
+    fg = FunctionGraph([x, y], [out])
+    fn, _ = compare_jax_and_py(fg, test_values)
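+
+
+def test_as_jax_op_gradient_sketch():
+    # Illustrative sketch, not part of the original changeset: it exercises the
+    # gradient path of the wrapped op by differentiating through the JAX function,
+    # mirroring test_as_jax_op1. It assumes compare_jax_and_py handles the gradient
+    # graph in the same way as the forward graph.
+    import pytensor.tensor as pt
+
+    rng = np.random.default_rng(15)
+    x = tensor("a", shape=(2,))
+    y = tensor("b", shape=(2,))
+    test_values = [
+        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
+    ]
+
+    @as_jax_op
+    def f(x, y):
+        return jax.nn.sigmoid(x + y)
+
+    out = f(x, y)
+    grad_out = pt.grad(pt.sum(out), [x, y])
+
+    fg = FunctionGraph([x, y], grad_out)
+    fn, _ = compare_jax_and_py(fg, test_values)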