Skip to content

Commit

Permalink
[math] Some code simplifications in CladDerivator.h
Browse files Browse the repository at this point in the history
In particular, to get rid of `goto` statements.
  • Loading branch information
guitargeek committed Oct 28, 2024
1 parent 47b6ce6 commit 94fc289
Showing 1 changed file with 37 additions and 74 deletions.
111 changes: 37 additions & 74 deletions math/mathcore/inc/Math/CladDerivator.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ ValueAndPushforward<Double_t, Double_t> Ln10_pushforward()
#endif
} // namespace TMath


namespace ROOT {
namespace Math {

Expand Down Expand Up @@ -297,7 +296,6 @@ inline void landau_pdf_pullback(double x, double xi, double x0, double d_out, do
double _d_v = 0;
double _d_denlan = 0;
if (v < -5.5) {
double _t0;
double u = ::std::exp(v + 1.);
double _d_u = 0;
if (u >= 1.e-10) {
Expand Down Expand Up @@ -326,15 +324,14 @@ inline void landau_pdf_pullback(double x, double xi, double x0, double d_out, do
double _r_d1 = _d_ue;
_d_ue -= _r_d1;
double _r2 = 0;
_r2 += _r_d1 * clad::custom_derivatives::exp_pushforward(-1 / u, 1.).pushforward;
_r2 += _r_d1 * ::std::exp(-1 / u);
double _r3 = _r2 * -(-1 / (u * u));
_d_u += _r3;
}
u = _t0;
double _r_d0 = _d_u;
_d_u -= _r_d0;
double _r1 = 0;
_r1 += _r_d0 * clad::custom_derivatives::exp_pushforward(v + 1., 1.).pushforward;
_r1 += _r_d0 * ::std::exp(v + 1.);
_d_v += _r1;
} else if (v < -1) {
double _t4;
Expand All @@ -351,8 +348,7 @@ inline void landau_pdf_pullback(double x, double xi, double x0, double d_out, do
double _r_d5 = _d_denlan;
_d_denlan -= _r_d5;
double _r7 = 0;
_r7 += _r_d5 / _t6 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) * _t7 *
clad::custom_derivatives::exp_pushforward(-u, 1.).pushforward;
_r7 += _r_d5 / _t6 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) * _t7 * ::std::exp(-u);
_d_u += -_r7;
double _r8 = 0;
_r8 += _t8 * _r_d5 / _t6 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) *
Expand All @@ -371,7 +367,7 @@ inline void landau_pdf_pullback(double x, double xi, double x0, double d_out, do
double _r_d4 = _d_u;
_d_u -= _r_d4;
double _r6 = 0;
_r6 += _r_d4 * clad::custom_derivatives::exp_pushforward(-v - 1, 1.).pushforward;
_r6 += _r_d4 * ::std::exp(-v - 1);
_d_v += -_r6;
} else if (v < 1) {
double _t9;
Expand Down Expand Up @@ -514,7 +510,7 @@ inline void landau_pdf_pullback(double x, double xi, double x0, double d_out, do
_d_v += _r18;
_d_v += -_r18 / _t24 * _t25;
double _r19 = 0;
_r19 += v * -_r18 / _t24 * clad::custom_derivatives::log_pushforward(v, 1.).pushforward;
_r19 += v * -_r18 / _t24 / v;
_d_v += _r19;
double _r20 = -_r18 * -(v * _t25 / (_t24 * _t24));
_d_v += _r20;
Expand Down Expand Up @@ -559,8 +555,7 @@ inline void landau_cdf_pullback(double x, double xi, double x0, double d_out, do
double _t3 = ::std::exp(-1. / u);
double _t2 = ::std::sqrt(u);
double _r2 = 0;
_r2 += _const0 * d_out * (1 + (a1[1] + (a1[2] + a1[3] * u) * u) * u) * _t2 *
clad::custom_derivatives::exp_pushforward(-1. / u, 1.).pushforward;
_r2 += _const0 * d_out * (1 + (a1[1] + (a1[2] + a1[3] * u) * u) * u) * _t2 * ::std::exp(-1. / u);
double _r3 = _r2 * -(-1. / (u * u));
_d_u += _r3;
double _r4 = 0;
Expand All @@ -570,16 +565,15 @@ inline void landau_cdf_pullback(double x, double xi, double x0, double d_out, do
_d_u += a1[3] * _const0 * _t3 * _t2 * d_out * u * u;
_d_u += (a1[2] + a1[3] * u) * _const0 * _t3 * _t2 * d_out * u;
_d_u += (a1[1] + (a1[2] + a1[3] * u) * u) * _const0 * _t3 * _t2 * d_out;
_d_v += _d_u * clad::custom_derivatives::exp_pushforward(v + 1, 1.).pushforward;
_d_v += _d_u * ::std::exp(v + 1);
} else if (v < -1) {
double _d_u = 0;
double u = ::std::exp(-v - 1);
double _t8 = ::std::exp(-u);
double _t7 = ::std::sqrt(u);
double _t6 = (q1[0] + (q1[1] + (q1[2] + (q1[3] + q1[4] * v) * v) * v) * v);
double _r6 = 0;
_r6 += d_out / _t6 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) / _t7 *
clad::custom_derivatives::exp_pushforward(-u, 1.).pushforward;
_r6 += d_out / _t6 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) / _t7 * ::std::exp(-u);
_d_u += -_r6;
double _r7 = d_out / _t6 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) * -(_t8 / (_t7 * _t7));
double _r8 = 0;
Expand All @@ -594,7 +588,7 @@ inline void landau_cdf_pullback(double x, double xi, double x0, double d_out, do
_d_v += (q1[3] + q1[4] * v) * _r9 * v * v;
_d_v += (q1[2] + (q1[3] + q1[4] * v) * v) * _r9 * v;
_d_v += (q1[1] + (q1[2] + (q1[3] + q1[4] * v) * v) * v) * _r9;
_d_v += -_d_u * clad::custom_derivatives::exp_pushforward(-v - 1, 1.).pushforward;
_d_v += -_d_u * ::std::exp(-v - 1);
} else if (v < 1) {
double _t10 = (q2[0] + (q2[1] + (q2[2] + q2[3] * v) * v) * v);
_d_v += p2[3] * d_out / _t10 * v * v;
Expand Down Expand Up @@ -666,7 +660,7 @@ inline void landau_cdf_pullback(double x, double xi, double x0, double d_out, do
_d_v += _r18;
_d_v += -_r18 / _t24 * _t25;
double _r19 = 0;
_r19 += v * -_r18 / _t24 * clad::custom_derivatives::log_pushforward(v, 1.).pushforward;
_r19 += v * -_r18 / _t24 / v;
_d_v += _r19;
double _r20 = -_r18 * -(v * _t25 / (_t24 * _t24));
_d_v += _r20;
Expand All @@ -693,42 +687,35 @@ inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, do
constexpr double kBiginv = 2.22044604925031308085e-16;

double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_r = 0;
bool _cond0;
bool _cond1;
bool _cond2;
double _t0;
double _t1;
bool _cond3;
double _t2;
double _t3;
double _t4;
double _t5;
unsigned long _t6;
clad::tape<double> _t7 = {};
clad::tape<double> _t8 = {};
clad::tape<double> _t9 = {};
double ans, ax, c, r;
_cond0 = a <= 0;
if (_cond0)
if (a <= 0)
return;
_cond1 = x <= 0;
if (_cond1)
if (x <= 0)
return;
_cond2 = (x > 1.) && (x > a);
if (_cond2) {
if ((x > 1.) && (x > a)) {
double _r0 = 0;
double _r1 = 0;
inc_gamma_c_pullback(a, x, -_d_y, &_r0, &_r1);
*_d_a += _r0;
*_d_x += _r1;
return;
}
_t0 = ax;
_t1 = ::std::log(x);
ax = a * _t1 - x - ::std::lgamma(a);
_cond3 = ax < -kMAXLOG;
if (_cond3)
goto _label3;
if (ax < -kMAXLOG) {
*_d_x += (a * _d_ax / x) - _d_ax;
*_d_a += _d_ax * (_t1 - ::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
_d_ax = 0.;
return;
}
_t2 = ax;
ax = ::std::exp(ax);
_t3 = r;
Expand All @@ -737,7 +724,7 @@ inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, do
c = 1.;
_t5 = ans;
ans = 1.;
_t6 = 0;
unsigned long _t6 = 0;
do {
_t6++;
clad::push(_t7, r);
Expand Down Expand Up @@ -797,23 +784,13 @@ inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, do
double _r_d1 = _d_ax;
_d_ax -= _r_d1;
double _r4 = 0;
_r4 += _r_d1 * clad::custom_derivatives::exp_pushforward(ax, 1.).pushforward;
_r4 += _r_d1 * ::std::exp(ax);
_d_ax += _r4;
}
if (_cond3)
_label3:;
{
ax = _t0;
double _r_d0 = _d_ax;
_d_ax -= _r_d0;
*_d_a += _r_d0 * _t1;
double _r2 = 0;
_r2 += a * _r_d0 * clad::custom_derivatives::log_pushforward(x, 1.).pushforward;
*_d_x += _r2;
*_d_x += -_r_d0;
double _r3 = 0;
_r3 += -_r_d0 * ::ROOT::Math::digamma(a); //numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
*_d_a += _r3;
*_d_x += (a * _d_ax / x) - _d_ax;
*_d_a += _d_ax * (_t1 - ::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
_d_ax = 0.;
}
}

Expand All @@ -830,12 +807,7 @@ inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a,

double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_yc = 0, _d_r = 0, _d_t = 0, _d_y0 = 0, _d_z = 0;
double _d_pk = 0, _d_pkm1 = 0, _d_pkm2 = 0, _d_qk = 0, _d_qkm1 = 0, _d_qkm2 = 0;
bool _cond0;
bool _cond1;
bool _cond2;
double _t0;
double _t1;
bool _cond3;
double _t2;
double _t3;
double _t4;
Expand Down Expand Up @@ -868,27 +840,26 @@ inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a,
clad::tape<double> _t33 = {};
double ans, ax, c, yc, r, t, y, z;
double pk, pkm1, pkm2, qk, qkm1, qkm2;
_cond0 = a <= 0;
if (_cond0)
if (a <= 0)
return;
_cond1 = x <= 0;
if (_cond1)
if (x <= 0)
return;
_cond2 = (x < 1.) || (x < a);
if (_cond2) {
if ((x < 1.) || (x < a)) {
double _r0 = 0;
double _r1 = 0;
inc_gamma_pullback(a, x, -_d_y, &_r0, &_r1);
*_d_a += _r0;
*_d_x += _r1;
return;
}
_t0 = ax;
_t1 = ::std::log(x);
ax = a * _t1 - x - ::std::lgamma(a);
_cond3 = ax < -kMAXLOG;
if (_cond3)
goto _label3;
if (ax < -kMAXLOG) {
*_d_x += a * _d_ax / x - _d_ax;
*_d_a += _d_ax * (_t1 -::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
_d_ax = 0.;
return;
}
_t2 = ax;
ax = ::std::exp(ax);
_t3 = y;
Expand Down Expand Up @@ -1141,21 +1112,13 @@ inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a,
ax = _t2;
double _r_d1 = _d_ax;
_d_ax -= _r_d1;
double _r4 = _r_d1 * clad::custom_derivatives::exp_pushforward(ax, 1.).pushforward;
double _r4 = _r_d1 * ::std::exp(ax);
_d_ax += _r4;
}
if (_cond3)
_label3:;
{
ax = _t0;
double _r_d0 = _d_ax;
_d_ax -= _r_d0;
*_d_a += _r_d0 * _t1;
double _r2 = a * _r_d0 * clad::custom_derivatives::log_pushforward(x, 1.).pushforward;
*_d_x += _r2;
*_d_x += -_r_d0;
double _r3 = -_r_d0 * ::ROOT::Math::digamma(a); //numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
*_d_a += _r3;
*_d_x += a * _d_ax / x - _d_ax;
*_d_a += _d_ax * (_t1 -::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
_d_ax = 0.;
}
}

Expand Down

0 comments on commit 94fc289

Please sign in to comment.