Skip to content

Commit

Permalink
Add stabilization rewrite for log of kv
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 13, 2024
1 parent 133abe8 commit 33a4d48
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
ge,
int_div,
isinf,
kve,
le,
log,
log1mexp,
Expand Down Expand Up @@ -3494,3 +3495,18 @@ def local_useless_conj(fgraph, node):
)

register_specialize(local_polygamma_to_tri_gamma)


local_log_kv = PatternNodeRewriter(
# Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
# During stabilize -x is converted to -1.0 * x
(log, (mul, (kve, "v", "x"), (exp, (mul, -1.0, "x")))),
(sub, (log, (kve, "v", "x")), "x"),
allow_multiple_clients=True,
name="local_log_kv",
# Start the rewrite from the less likely kve node
tracks=[kve],
get_nodes=get_clients_at_depth2,
)

register_stabilize(local_log_kv)
15 changes: 15 additions & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
ge,
gt,
int_div,
kv,
le,
log,
log1mexp,
Expand Down Expand Up @@ -4578,3 +4579,17 @@ def test_local_batched_matmul_to_core_matmul():
x_test = rng.normal(size=(5, 3, 2))
y_test = rng.normal(size=(5, 2, 2))
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)


def test_log_kv_stabilization():
x = pt.scalar("x")
out = log(kv(4.5, x))

# Expression would underflow to -inf without rewrite
mode = get_default_mode().including("stabilize")
# Reference value from mpmath
# mpmath.log(mpmath.besselk(4.5, 1000.0))
np.testing.assert_allclose(
out.eval({x: 1000.0}, mode=mode),
-1003.2180912984705,
)

0 comments on commit 33a4d48

Please sign in to comment.