Skip to content

Commit

Permalink
Implement Kv Op
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 11, 2024
1 parent fdbf3aa commit 0d88da3
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 7 deletions.
11 changes: 11 additions & 0 deletions pytensor/link/jax/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
GammaIncInv,
Iv,
Ive,
Kv,
Log1mexp,
Psi,
TriGamma,
Expand Down Expand Up @@ -293,6 +294,16 @@ def jax_funcify_Ive(op, **kwargs):
return ive


@jax_funcify.register(Kv)
def jax_funcify_Kv(op, **kwargs):
kve = try_import_tfp_jax_op(op, jax_op_name="bessel_kve")

def kv(v, x):
return kve(v, x) / jnp.exp(jnp.abs(x))

return kv


@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
def log1mexp(x):
Expand Down
30 changes: 30 additions & 0 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,36 @@ def c_code(self, *args, **kwargs):
ive = Ive(upgrade_to_float, name="ive")


class Kv(BinaryScalarOp):
"""Modified Bessel function of the second kind of real order v."""

nfunc_spec = ("scipy.special.kv", 2, 1)

@staticmethod
def st_impl(v, x):
return scipy.special.kv(v, x)

Check warning on line 1291 in pytensor/scalar/math.py

View check run for this annotation

Codecov / codecov/patch

pytensor/scalar/math.py#L1291

Added line #L1291 was not covered by tests

def impl(self, v, x):
return self.st_impl(v, x)

Check warning on line 1294 in pytensor/scalar/math.py

View check run for this annotation

Codecov / codecov/patch

pytensor/scalar/math.py#L1294

Added line #L1294 was not covered by tests

def L_op(self, inputs, outputs, output_grads):
v, x = inputs
[out] = outputs
[g_out] = output_grads
# -(v / x) * kv(v, x) - kv(v - 1, x)
dx = -(v / x) * out - self(v - 1, x)
return [
grad_not_implemented(self, 0, v),
g_out * dx,
]

def c_code(self, *args, **kwargs):
raise NotImplementedError()


kv = Kv(upgrade_to_float, name="kv")


class Sigmoid(UnaryScalarOp):
"""
Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit
Expand Down
6 changes: 6 additions & 0 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,11 @@ def ive(v, x):
"""Exponentially scaled modified Bessel function of the first kind of order v (real)."""


@scalar_elemwise
def kv(v, x):
"""Modified Bessel function of the second kind of real order v."""


@scalar_elemwise
def sigmoid(x):
"""Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
Expand Down Expand Up @@ -3040,6 +3045,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
"i1",
"iv",
"ive",
"kv",
"sigmoid",
"expit",
"softplus",
Expand Down
7 changes: 2 additions & 5 deletions tests/link/jax/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
gammainccinv,
gammaincinv,
iv,
kv,
log,
log1mexp,
polygamma,
Expand Down Expand Up @@ -153,11 +154,7 @@ def test_erfinv():

@pytest.mark.parametrize(
"op, test_values",
[
(erfcx, (0.7,)),
(erfcinv, (0.7,)),
(iv, (0.3, 0.7)),
],
[(erfcx, (0.7,)), (erfcinv, (0.7,)), (iv, (0.3, 0.7)), (kv, (-2.5, 2.0))],
)
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
def test_tfp_ops(op, test_values):
Expand Down
24 changes: 22 additions & 2 deletions tests/tensor/test_math_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from pytensor.gradient import verify_grad
from pytensor.gradient import NullTypeGradError, verify_grad
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import Elemwise

Expand All @@ -18,7 +18,7 @@
from pytensor import tensor as pt
from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor import gammaincc, inplace, vector
from pytensor.tensor import gammaincc, inplace, kv, vector
from tests import unittest_tools as utt
from tests.tensor.utils import (
_good_broadcast_unary_chi2sf,
Expand Down Expand Up @@ -1196,3 +1196,23 @@ def test_unused_grad_loop_opt(self, wrt):
[dd for i, dd in enumerate(expected_dds) if i in wrt],
rtol=rtol,
)


def test_kv():
rng = np.random.default_rng(3772)
v = vector("v")
x = vector("x")

out = kv(v[:, None], x[None, :])
test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
test_x = np.linspace(0, 5, 10, dtype=x.type.dtype)

np.testing.assert_allclose(
out.eval({v: test_v, x: test_x}),
scipy.special.kv(test_v[:, None], test_x[None, :]),
)

with pytest.raises(NullTypeGradError):
grad(out.sum(), v)

verify_grad(lambda x: kv(4.5, x), [test_x + 0.5], rng=rng)

0 comments on commit 0d88da3

Please sign in to comment.