diff --git a/lineax/_norm.py b/lineax/_norm.py index 165fda8..b46f175 100644 --- a/lineax/_norm.py +++ b/lineax/_norm.py @@ -56,13 +56,19 @@ def sum_squares(x: PyTree[ArrayLike]) -> Scalar: return tree_dot(x, x).real -@jax.custom_jvp def two_norm(x: PyTree[ArrayLike]) -> Scalar: """Computes the L2 norm of a PyTree of arrays. Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes `sqrt(Σ_i x_i^2)` """ + # Wrap the `custom_jvp` into a function so that our autogenerated documentation + # displays the docstring correctly. + return _two_norm(x) + + +@jax.custom_jvp +def _two_norm(x: PyTree[ArrayLike]) -> Scalar: leaves = jtu.tree_leaves(x) size = sum([jnp.size(xi) for xi in leaves]) if size == 1: @@ -76,7 +82,7 @@ def two_norm(x: PyTree[ArrayLike]) -> Scalar: return jnp.sqrt(sum_squares(x)) -@two_norm.defjvp +@_two_norm.defjvp def _two_norm_jvp(x, tx): (x,) = x (tx,) = tx