Skip to content

Commit

Permalink
Doc fix for two_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Aug 18, 2024
1 parent 4a7b108 commit 2e10cce
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions lineax/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,19 @@ def sum_squares(x: PyTree[ArrayLike]) -> Scalar:
return tree_dot(x, x).real


@jax.custom_jvp
def two_norm(x: PyTree[ArrayLike]) -> Scalar:
"""Computes the L2 norm of a PyTree of arrays.
Considering the input `x` as a flat vector `(x_1, ..., x_n)`, then this computes
`sqrt(Σ_i x_i^2)`
"""
# Wrap the `custom_jvp` into a function so that our autogenerated documentation
# displays the docstring correctly.
return _two_norm(x)


@jax.custom_jvp
def _two_norm(x: PyTree[ArrayLike]) -> Scalar:
leaves = jtu.tree_leaves(x)
size = sum([jnp.size(xi) for xi in leaves])
if size == 1:
Expand All @@ -76,7 +82,7 @@ def two_norm(x: PyTree[ArrayLike]) -> Scalar:
return jnp.sqrt(sum_squares(x))


@two_norm.defjvp
@_two_norm.defjvp
def _two_norm_jvp(x, tx):
(x,) = x
(tx,) = tx
Expand Down

0 comments on commit 2e10cce

Please sign in to comment.