Skip to content

Commit

Permalink
Norm should return float rather than complex
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl authored and patrick-kidger committed Mar 23, 2024
1 parent 1d0ac6c commit 48eef77
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lineax/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def rms_norm(x: PyTree[ArrayLike]) -> Scalar:
if len(leaves) == 0:
dtype = default_floating_dtype()
else:
dtype = jnp.result_type(*leaves)
dtype = jnp.finfo(jnp.result_type(*leaves)).dtype
return jnp.array(0.0, dtype)
else:
return two_norm(x) / math.sqrt(size)
Expand All @@ -131,7 +131,7 @@ def max_norm(x: PyTree[ArrayLike]) -> Scalar:
if len(leaves) == 0:
dtype = default_floating_dtype()
else:
dtype = jnp.result_type(*leaves)
dtype = jnp.finfo(jnp.result_type(*leaves)).dtype
return jnp.array(0.0, dtype)
else:
out = ft.reduce(jnp.maximum, leaf_maxes)
Expand Down

0 comments on commit 48eef77

Please sign in to comment.