From 47b6ce6bd90a1b626017102cad8b15f5dc3f705b Mon Sep 17 00:00:00 2001 From: Jonas Rembser Date: Tue, 14 May 2024 02:42:31 +0200 Subject: [PATCH] [math] Support AD for `TMath::LnGamma` using functions from GSL This is to avoid any num-diff fallback in RooFit, which results in annoying warnings for the user. A new function `ROOT::Math::digamma` is added to the public interface, which wraps `gsl_sf_psi`. The digamma function is the derivative of `lgamma`, so it is used in `CladDerivator.h` to define the derivatives of `TMath::LnGamma` and the related gamma functions that are used to define Poisson cdfs. --- math/mathcore/inc/Math/CladDerivator.h | 502 ++++++++++++++++++ math/mathcore/inc/Math/PdfFuncMathCore.h | 6 +- math/mathmore/inc/Math/SpecFuncMathMore.h | 1 + math/mathmore/src/SpecFuncMathMore.cxx | 8 + .../roofitcore/inc/RooFit/Detail/MathFuncs.h | 8 +- 5 files changed, 518 insertions(+), 7 deletions(-) diff --git a/math/mathcore/inc/Math/CladDerivator.h b/math/mathcore/inc/Math/CladDerivator.h index 78cdf87c8e082..42bc0a8ab66be 100644 --- a/math/mathcore/inc/Math/CladDerivator.h +++ b/math/mathcore/inc/Math/CladDerivator.h @@ -24,6 +24,14 @@ #include #include "TMath.h" + +// For the digamma function, that is the derivative of lgamma. We get it via +// mathmore from the GSL, so the pullbacks that use digamma are only available +// with mathmore=ON. 
+#ifdef R__HAS_MATHMORE +#include "Math/SpecFuncMathMore.h" +#endif + namespace clad { namespace custom_derivatives { namespace TMath { @@ -93,6 +101,16 @@ ValueAndPushforward Erfc_pushforward(T x, T d_x) return {::TMath::Erfc(x), -Erf_pushforward(x, d_x).pushforward}; } +#ifdef R__HAS_MATHMORE + +template +ValueAndPushforward LnGamma_pushforward(T z, T d_z) +{ + return {::TMath::LnGamma(z), ::ROOT::Math::digamma(z) * d_z}; +} + +#endif + template ValueAndPushforward Exp_pushforward(T x, T d_x) { @@ -659,6 +677,490 @@ inline void landau_cdf_pullback(double x, double xi, double x0, double d_out, do *d_xi += _d_v * -((x - x0) / (xi * xi)); } +#ifdef R__HAS_MATHMORE + +inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x); + +inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x) +{ + // Synced with SpecFuncCephes.h + constexpr double kMACHEP = 1.11022302462515654042363166809e-16; + constexpr double kMAXLOG = 709.782712893383973096206318587; + constexpr double kMINLOG = -708.396418532264078748994506896; + constexpr double kMAXSTIR = 108.116855767857671821730036754; + constexpr double kMAXLGM = 2.556348e305; + constexpr double kBig = 4.503599627370496e15; + constexpr double kBiginv = 2.22044604925031308085e-16; + + double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_r = 0; + bool _cond0; + bool _cond1; + bool _cond2; + double _t0; + double _t1; + bool _cond3; + double _t2; + double _t3; + double _t4; + double _t5; + unsigned long _t6; + clad::tape _t7 = {}; + clad::tape _t8 = {}; + clad::tape _t9 = {}; + double ans, ax, c, r; + _cond0 = a <= 0; + if (_cond0) + return; + _cond1 = x <= 0; + if (_cond1) + return; + _cond2 = (x > 1.) 
&& (x > a); + if (_cond2) { + double _r0 = 0; + double _r1 = 0; + inc_gamma_c_pullback(a, x, -_d_y, &_r0, &_r1); + *_d_a += _r0; + *_d_x += _r1; + return; + } + _t0 = ax; + _t1 = ::std::log(x); + ax = a * _t1 - x - ::std::lgamma(a); + _cond3 = ax < -kMAXLOG; + if (_cond3) + goto _label3; + _t2 = ax; + ax = ::std::exp(ax); + _t3 = r; + r = a; + _t4 = c; + c = 1.; + _t5 = ans; + ans = 1.; + _t6 = 0; + do { + _t6++; + clad::push(_t7, r); + r += 1.; + clad::push(_t8, c); + c *= x / r; + clad::push(_t9, ans); + ans += c; + } while (c / ans > kMACHEP); + { + _d_ans += _d_y / a * ax; + _d_ax += ans * _d_y / a; + double _r6 = _d_y * -(ans * ax / (a * a)); + *_d_a += _r6; + } + do { + { + { + ans = clad::pop(_t9); + double _r_d7 = _d_ans; + _d_c += _r_d7; + } + { + c = clad::pop(_t8); + double _r_d6 = _d_c; + _d_c -= _r_d6; + _d_c += _r_d6 * x / r; + *_d_x += c * _r_d6 / r; + double _r5 = c * _r_d6 * -(x / (r * r)); + _d_r += _r5; + } + { + r = clad::pop(_t7); + double _r_d5 = _d_r; + } + } + _t6--; + } while (_t6); + { + ans = _t5; + double _r_d4 = _d_ans; + _d_ans -= _r_d4; + } + { + c = _t4; + double _r_d3 = _d_c; + _d_c -= _r_d3; + } + { + r = _t3; + double _r_d2 = _d_r; + _d_r -= _r_d2; + *_d_a += _r_d2; + } + { + ax = _t2; + double _r_d1 = _d_ax; + _d_ax -= _r_d1; + double _r4 = 0; + _r4 += _r_d1 * clad::custom_derivatives::exp_pushforward(ax, 1.).pushforward; + _d_ax += _r4; + } + if (_cond3) + _label3:; + { + ax = _t0; + double _r_d0 = _d_ax; + _d_ax -= _r_d0; + *_d_a += _r_d0 * _t1; + double _r2 = 0; + _r2 += a * _r_d0 * clad::custom_derivatives::log_pushforward(x, 1.).pushforward; + *_d_x += _r2; + *_d_x += -_r_d0; + double _r3 = 0; + _r3 += -_r_d0 * ::ROOT::Math::digamma(a); //numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a); + *_d_a += _r3; + } +} + +inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x) +{ + // Synced with SpecFuncCephes.h + constexpr double kMACHEP = 
1.11022302462515654042363166809e-16; + constexpr double kMAXLOG = 709.782712893383973096206318587; + constexpr double kMINLOG = -708.396418532264078748994506896; + constexpr double kMAXSTIR = 108.116855767857671821730036754; + constexpr double kMAXLGM = 2.556348e305; + constexpr double kBig = 4.503599627370496e15; + constexpr double kBiginv = 2.22044604925031308085e-16; + + double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_yc = 0, _d_r = 0, _d_t = 0, _d_y0 = 0, _d_z = 0; + double _d_pk = 0, _d_pkm1 = 0, _d_pkm2 = 0, _d_qk = 0, _d_qkm1 = 0, _d_qkm2 = 0; + bool _cond0; + bool _cond1; + bool _cond2; + double _t0; + double _t1; + bool _cond3; + double _t2; + double _t3; + double _t4; + double _t5; + double _t6; + double _t7; + double _t8; + double _t9; + double _t10; + unsigned long _t11; + clad::tape _t12 = {}; + clad::tape _t13 = {}; + clad::tape _t14 = {}; + clad::tape _t15 = {}; + clad::tape _t16 = {}; + clad::tape _t17 = {}; + clad::tape _t19 = {}; + clad::tape _t20 = {}; + clad::tape _t21 = {}; + clad::tape _t22 = {}; + clad::tape _t23 = {}; + clad::tape _t24 = {}; + clad::tape _t25 = {}; + clad::tape _t26 = {}; + clad::tape _t27 = {}; + clad::tape _t29 = {}; + clad::tape _t30 = {}; + clad::tape _t31 = {}; + clad::tape _t32 = {}; + clad::tape _t33 = {}; + double ans, ax, c, yc, r, t, y, z; + double pk, pkm1, pkm2, qk, qkm1, qkm2; + _cond0 = a <= 0; + if (_cond0) + return; + _cond1 = x <= 0; + if (_cond1) + return; + _cond2 = (x < 1.) || (x < a); + if (_cond2) { + double _r0 = 0; + double _r1 = 0; + inc_gamma_pullback(a, x, -_d_y, &_r0, &_r1); + *_d_a += _r0; + *_d_x += _r1; + return; + } + _t0 = ax; + _t1 = ::std::log(x); + ax = a * _t1 - x - ::std::lgamma(a); + _cond3 = ax < -kMAXLOG; + if (_cond3) + goto _label3; + _t2 = ax; + ax = ::std::exp(ax); + _t3 = y; + y = 1. 
- a; + _t4 = z; + z = x + y + 1.; + _t5 = c; + c = 0.; + _t6 = pkm2; + pkm2 = 1.; + _t7 = qkm2; + qkm2 = x; + _t8 = pkm1; + pkm1 = x + 1.; + _t9 = qkm1; + qkm1 = z * x; + _t10 = ans; + ans = pkm1 / qkm1; + _t11 = 0; + do { + _t11++; + clad::push(_t12, c); + c += 1.; + clad::push(_t13, y); + y += 1.; + clad::push(_t14, z); + z += 2.; + clad::push(_t15, yc); + yc = y * c; + clad::push(_t16, pk); + pk = pkm1 * z - pkm2 * yc; + clad::push(_t17, qk); + qk = qkm1 * z - qkm2 * yc; + double _t18 = qk; + { + if (_t18) { + clad::push(_t20, r); + r = pk / qk; + clad::push(_t21, t); + t = ::std::abs((ans - r) / r); + clad::push(_t22, ans); + ans = r; + } else { + clad::push(_t23, t); + t = 1.; + } + clad::push(_t19, _t18); + } + clad::push(_t24, pkm2); + pkm2 = pkm1; + clad::push(_t25, pkm1); + pkm1 = pk; + clad::push(_t26, qkm2); + qkm2 = qkm1; + clad::push(_t27, qkm1); + qkm1 = qk; + bool _t28 = ::std::abs(pk) > kBig; + { + if (_t28) { + clad::push(_t30, pkm2); + pkm2 *= kBiginv; + clad::push(_t31, pkm1); + pkm1 *= kBiginv; + clad::push(_t32, qkm2); + qkm2 *= kBiginv; + clad::push(_t33, qkm1); + qkm1 *= kBiginv; + } + clad::push(_t29, _t28); + } + } while (t > kMACHEP); + { + _d_ans += _d_y * ax; + _d_ax += ans * _d_y; + } + do { + { + if (clad::pop(_t29)) { + { + qkm1 = clad::pop(_t33); + double _r_d27 = _d_qkm1; + _d_qkm1 -= _r_d27; + _d_qkm1 += _r_d27 * kBiginv; + } + { + qkm2 = clad::pop(_t32); + double _r_d26 = _d_qkm2; + _d_qkm2 -= _r_d26; + _d_qkm2 += _r_d26 * kBiginv; + } + { + pkm1 = clad::pop(_t31); + double _r_d25 = _d_pkm1; + _d_pkm1 -= _r_d25; + _d_pkm1 += _r_d25 * kBiginv; + } + { + pkm2 = clad::pop(_t30); + double _r_d24 = _d_pkm2; + _d_pkm2 -= _r_d24; + _d_pkm2 += _r_d24 * kBiginv; + } + } + { + qkm1 = clad::pop(_t27); + double _r_d23 = _d_qkm1; + _d_qkm1 -= _r_d23; + _d_qk += _r_d23; + } + { + qkm2 = clad::pop(_t26); + double _r_d22 = _d_qkm2; + _d_qkm2 -= _r_d22; + _d_qkm1 += _r_d22; + } + { + pkm1 = clad::pop(_t25); + double _r_d21 = _d_pkm1; + _d_pkm1 -= 
_r_d21; + _d_pk += _r_d21; + } + { + pkm2 = clad::pop(_t24); + double _r_d20 = _d_pkm2; + _d_pkm2 -= _r_d20; + _d_pkm1 += _r_d20; + } + if (clad::pop(_t19)) { + { + ans = clad::pop(_t22); + double _r_d18 = _d_ans; + _d_ans -= _r_d18; + _d_r += _r_d18; + } + { + t = clad::pop(_t21); + double _r_d17 = _d_t; + _d_t -= _r_d17; + double _r7 = 0; + _r7 += _r_d17 * clad::custom_derivatives::std::abs_pushforward((ans - r) / r, 1.).pushforward; + _d_ans += _r7 / r; + _d_r += -_r7 / r; + double _r8 = _r7 * -((ans - r) / (r * r)); + _d_r += _r8; + } + { + r = clad::pop(_t20); + double _r_d16 = _d_r; + _d_r -= _r_d16; + _d_pk += _r_d16 / qk; + double _r6 = _r_d16 * -(pk / (qk * qk)); + _d_qk += _r6; + } + } else { + t = clad::pop(_t23); + double _r_d19 = _d_t; + _d_t -= _r_d19; + } + { + qk = clad::pop(_t17); + double _r_d15 = _d_qk; + _d_qk -= _r_d15; + _d_qkm1 += _r_d15 * z; + _d_z += qkm1 * _r_d15; + _d_qkm2 += -_r_d15 * yc; + _d_yc += qkm2 * -_r_d15; + } + { + pk = clad::pop(_t16); + double _r_d14 = _d_pk; + _d_pk -= _r_d14; + _d_pkm1 += _r_d14 * z; + _d_z += pkm1 * _r_d14; + _d_pkm2 += -_r_d14 * yc; + _d_yc += pkm2 * -_r_d14; + } + { + yc = clad::pop(_t15); + double _r_d13 = _d_yc; + _d_yc -= _r_d13; + _d_y0 += _r_d13 * c; + _d_c += y * _r_d13; + } + { + z = clad::pop(_t14); + double _r_d12 = _d_z; + } + { + y = clad::pop(_t13); + double _r_d11 = _d_y0; + } + { + c = clad::pop(_t12); + double _r_d10 = _d_c; + } + } + _t11--; + } while (_t11); + { + ans = _t10; + double _r_d9 = _d_ans; + _d_ans -= _r_d9; + _d_pkm1 += _r_d9 / qkm1; + double _r5 = _r_d9 * -(pkm1 / (qkm1 * qkm1)); + _d_qkm1 += _r5; + } + { + qkm1 = _t9; + double _r_d8 = _d_qkm1; + _d_qkm1 -= _r_d8; + _d_z += _r_d8 * x; + *_d_x += z * _r_d8; + } + { + pkm1 = _t8; + double _r_d7 = _d_pkm1; + _d_pkm1 -= _r_d7; + *_d_x += _r_d7; + } + { + qkm2 = _t7; + double _r_d6 = _d_qkm2; + _d_qkm2 -= _r_d6; + *_d_x += _r_d6; + } + { + pkm2 = _t6; + double _r_d5 = _d_pkm2; + _d_pkm2 -= _r_d5; + } + { + c = _t5; + double _r_d4 
= _d_c; + _d_c -= _r_d4; + } + { + z = _t4; + double _r_d3 = _d_z; + _d_z -= _r_d3; + *_d_x += _r_d3; + _d_y0 += _r_d3; + } + { + y = _t3; + double _r_d2 = _d_y0; + _d_y0 -= _r_d2; + *_d_a += -_r_d2; + } + { + ax = _t2; + double _r_d1 = _d_ax; + _d_ax -= _r_d1; + double _r4 = _r_d1 * clad::custom_derivatives::exp_pushforward(ax, 1.).pushforward; + _d_ax += _r4; + } + if (_cond3) + _label3:; + { + ax = _t0; + double _r_d0 = _d_ax; + _d_ax -= _r_d0; + *_d_a += _r_d0 * _t1; + double _r2 = a * _r_d0 * clad::custom_derivatives::log_pushforward(x, 1.).pushforward; + *_d_x += _r2; + *_d_x += -_r_d0; + double _r3 = -_r_d0 * ::ROOT::Math::digamma(a); //numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a); + *_d_a += _r3; + } +} + +#endif // R__HAS_MATHMORE + } // namespace Math } // namespace ROOT diff --git a/math/mathcore/inc/Math/PdfFuncMathCore.h b/math/mathcore/inc/Math/PdfFuncMathCore.h index 1e4937706af77..0bb2b76718804 100644 --- a/math/mathcore/inc/Math/PdfFuncMathCore.h +++ b/math/mathcore/inc/Math/PdfFuncMathCore.h @@ -402,7 +402,7 @@ namespace Math { inline double gaussian_pdf(double x, double sigma = 1, double x0 = 0) { double tmp = (x-x0)/sigma; - return (1.0/(std::sqrt(2 * M_PI) * std::fabs(sigma))) * std::exp(-tmp*tmp/2); + return (1.0/(std::sqrt(2 * M_PI) * std::abs(sigma))) * std::exp(-tmp*tmp/2); } /** @@ -485,7 +485,7 @@ namespace Math { if ((x-x0) <= 0) return 0.0; double tmp = (std::log((x-x0)) - m)/s; - return 1.0 / ((x-x0) * std::fabs(s) * std::sqrt(2 * M_PI)) * std::exp(-(tmp * tmp) /2); + return 1.0 / ((x-x0) * std::abs(s) * std::sqrt(2 * M_PI)) * std::exp(-(tmp * tmp) /2); } @@ -510,7 +510,7 @@ namespace Math { // Inlined to enable clad-auto-derivation for this function. 
double tmp = (x-x0)/sigma; - return (1.0/(std::sqrt(2 * M_PI) * std::fabs(sigma))) * std::exp(-tmp*tmp/2); + return (1.0/(std::sqrt(2 * M_PI) * std::abs(sigma))) * std::exp(-tmp*tmp/2); } diff --git a/math/mathmore/inc/Math/SpecFuncMathMore.h b/math/mathmore/inc/Math/SpecFuncMathMore.h index 6670c0814d5da..564729efbb39f 100644 --- a/math/mathmore/inc/Math/SpecFuncMathMore.h +++ b/math/mathmore/inc/Math/SpecFuncMathMore.h @@ -873,6 +873,7 @@ namespace Math { double wigner_9j(int two_ja, int two_jb, int two_jc, int two_jd, int two_je, int two_jf, int two_jg, int two_jh, int two_ji); + double digamma(double x); } // namespace Math diff --git a/math/mathmore/src/SpecFuncMathMore.cxx b/math/mathmore/src/SpecFuncMathMore.cxx index a3480a945f713..73ed6c98149fe 100644 --- a/math/mathmore/src/SpecFuncMathMore.cxx +++ b/math/mathmore/src/SpecFuncMathMore.cxx @@ -28,6 +28,7 @@ #include "gsl/gsl_sf_zeta.h" #include "gsl/gsl_sf_airy.h" #include "gsl/gsl_sf_coupling.h" +#include "gsl/gsl_sf_psi.h" namespace ROOT { @@ -474,5 +475,12 @@ double wigner_9j(int ja, int jb, int jc, int jd, int je, int jf, int jg, int jh, return gsl_sf_coupling_9j(ja,jb,jc,jd,je,jf,jg,jh,ji); } +// Psi (Digamma) Function. + +double digamma(double x) +{ + return gsl_sf_psi(x); +} + } // namespace Math } // namespace ROOT diff --git a/roofit/roofitcore/inc/RooFit/Detail/MathFuncs.h b/roofit/roofitcore/inc/RooFit/Detail/MathFuncs.h index 81b49b0fcbe90..7119cc4b9a55d 100644 --- a/roofit/roofitcore/inc/RooFit/Detail/MathFuncs.h +++ b/roofit/roofitcore/inc/RooFit/Detail/MathFuncs.h @@ -605,14 +605,14 @@ poissonIntegral(int code, double mu, double x, double integrandMin, double integ // Sum from 0 to just before the bin outside of the range. 
if (ixMin == 0) { - return ROOT::Math::gamma_cdf_c(mu, ixMax, 1); + return ROOT::Math::inc_gamma_c(ixMax, mu); } else { // If necessary, subtract from 0 to the beginning of the range if (ixMin <= mu) { - return ROOT::Math::gamma_cdf_c(mu, ixMax, 1) - ROOT::Math::gamma_cdf_c(mu, ixMin, 1); + return ROOT::Math::inc_gamma_c(ixMax, mu) - ROOT::Math::inc_gamma_c(ixMin, mu); } else { // Avoid catastrophic cancellation in the high tails: - return ROOT::Math::gamma_cdf(mu, ixMin, 1) - ROOT::Math::gamma_cdf(mu, ixMax, 1); + return ROOT::Math::inc_gamma(ixMin, mu) - ROOT::Math::inc_gamma(ixMax, mu); } } } @@ -621,7 +621,7 @@ poissonIntegral(int code, double mu, double x, double integrandMin, double integ // negative ix does not need protection (gamma returns 0.0) const double ix = 1 + x; - return ROOT::Math::gamma_cdf(integrandMax, ix, 1.0) - ROOT::Math::gamma_cdf(integrandMin, ix, 1.0); + return ROOT::Math::inc_gamma(ix, integrandMax) - ROOT::Math::inc_gamma(ix, integrandMin); } inline double logNormalIntegral(double xMin, double xMax, double m0, double k)