diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py
index 71ea40de0f..64a077ae94 100644
--- a/pytensor/link/jax/dispatch/scalar.py
+++ b/pytensor/link/jax/dispatch/scalar.py
@@ -31,6 +31,7 @@
     GammaIncInv,
     Iv,
     Ive,
+    Kv,
     Log1mexp,
     Psi,
     TriGamma,
@@ -293,6 +294,16 @@ def jax_funcify_Ive(op, **kwargs):
     return ive
 
 
+@jax_funcify.register(Kv)
+def jax_funcify_Kv(op, **kwargs):
+    kve = try_import_tfp_jax_op(op, jax_op_name="bessel_kve")
+
+    def kv(v, x):
+        return kve(v, x) / jnp.exp(jnp.abs(x))
+
+    return kv
+
+
 @jax_funcify.register(Log1mexp)
 def jax_funcify_Log1mexp(op, node, **kwargs):
     def log1mexp(x):
diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py
index e3379492fa..bdd328510f 100644
--- a/pytensor/scalar/math.py
+++ b/pytensor/scalar/math.py
@@ -1281,6 +1281,36 @@ def c_code(self, *args, **kwargs):
 ive = Ive(upgrade_to_float, name="ive")
 
 
+class Kv(BinaryScalarOp):
+    """Modified Bessel function of the second kind of real order v."""
+
+    nfunc_spec = ("scipy.special.kv", 2, 1)
+
+    @staticmethod
+    def st_impl(v, x):
+        return scipy.special.kv(v, x)
+
+    def impl(self, v, x):
+        return self.st_impl(v, x)
+
+    def L_op(self, inputs, outputs, output_grads):
+        v, x = inputs
+        [out] = outputs
+        [g_out] = output_grads
+        # dK_v(x)/dx = -(v / x) * K_v(x) - K_{v-1}(x)
+        dx = -(v / x) * out - self(v - 1, x)
+        return [
+            grad_not_implemented(self, 0, v),
+            g_out * dx,
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+kv = Kv(upgrade_to_float, name="kv")
+
+
 class Sigmoid(UnaryScalarOp):
     """
     Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index d1e4dc6195..2d297f82e0 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -1229,6 +1229,11 @@ def ive(v, x):
     """Exponentially scaled modified Bessel function of the first kind of order v (real)."""
 
 
+@scalar_elemwise
+def kv(v, x):
+    """Modified Bessel function of the second kind of real order v."""
+
+
 @scalar_elemwise
 def sigmoid(x):
     """Logistic sigmoid function (1 / (1 + exp(-x)), also known as expit or inverse logit"""
@@ -3040,6 +3045,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
     "i1",
     "iv",
     "ive",
+    "kv",
     "sigmoid",
     "expit",
     "softplus",
diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py
index 0469301791..e3a377f69b 100644
--- a/tests/link/jax/test_scalar.py
+++ b/tests/link/jax/test_scalar.py
@@ -21,6 +21,7 @@
     gammainccinv,
     gammaincinv,
     iv,
+    kv,
     log,
     log1mexp,
     polygamma,
@@ -153,11 +154,7 @@ def test_erfinv():
 
 @pytest.mark.parametrize(
     "op, test_values",
-    [
-        (erfcx, (0.7,)),
-        (erfcinv, (0.7,)),
-        (iv, (0.3, 0.7)),
-    ],
+    [(erfcx, (0.7,)), (erfcinv, (0.7,)), (iv, (0.3, 0.7)), (kv, (-2.5, 2.0))],
 )
 @pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
 def test_tfp_ops(op, test_values):
diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py
index 6ca9279bca..8b638f84b7 100644
--- a/tests/tensor/test_math_scipy.py
+++ b/tests/tensor/test_math_scipy.py
@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 
-from pytensor.gradient import verify_grad
+from pytensor.gradient import NullTypeGradError, verify_grad
 from pytensor.scalar import ScalarLoop
 from pytensor.tensor.elemwise import Elemwise
 
@@ -18,7 +18,7 @@
 from pytensor import tensor as pt
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
-from pytensor.tensor import gammaincc, inplace, vector
+from pytensor.tensor import gammaincc, inplace, kv, vector
 from tests import unittest_tools as utt
 from tests.tensor.utils import (
     _good_broadcast_unary_chi2sf,
@@ -1196,3 +1196,23 @@ def test_unused_grad_loop_opt(self, wrt):
         [dd for i, dd in enumerate(expected_dds) if i in wrt],
         rtol=rtol,
     )
+
+
+def test_kv():
+    rng = np.random.default_rng(3772)
+    v = vector("v")
+    x = vector("x")
+
+    out = kv(v[:, None], x[None, :])
+    test_v = np.array([-3.7, 4, 4.5, 5], dtype=v.type.dtype)
+    test_x = np.linspace(0, 5, 10, dtype=x.type.dtype)
+
+    np.testing.assert_allclose(
+        out.eval({v: test_v, x: test_x}),
+        scipy.special.kv(test_v[:, None], test_x[None, :]),
+    )
+
+    with pytest.raises(NullTypeGradError):
+        grad(out.sum(), v)
+
+    verify_grad(lambda x: kv(4.5, x), [test_x + 0.5], rng=rng)
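
Usage sketch (reviewer note, not part of the patch): the `L_op` added in `pytensor/scalar/math.py` relies on the Bessel recurrence dK_v(x)/dx = -(v / x) * K_v(x) - K_{v-1}(x), which is why `dx` combines the op's own output with `self(v - 1, x)`. The snippet below is a minimal standalone check of that identity against a finite difference, assuming only NumPy and SciPy; the names `fd`, `analytic`, and `eps` are illustrative, not from the patch.

    import numpy as np
    import scipy.special

    v, x = 4.5, 2.0
    eps = 1e-6

    # Central finite difference of K_v(x) with respect to x
    fd = (scipy.special.kv(v, x + eps) - scipy.special.kv(v, x - eps)) / (2 * eps)

    # Closed form used by Kv.L_op: -(v / x) * K_v(x) - K_{v-1}(x)
    analytic = -(v / x) * scipy.special.kv(v, x) - scipy.special.kv(v - 1, x)

    np.testing.assert_allclose(fd, analytic, rtol=1e-7)

With the patch applied, the same quantity is available elementwise as `pytensor.tensor.kv`, differentiable only with respect to `x`; taking the gradient with respect to the order `v` raises `NullTypeGradError`, as `test_kv` asserts.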