Skip to content

Commit

Permalink
[RF][HF] Clearly mark interpolation code 3 as unknown
Browse files Browse the repository at this point in the history
As noted in GitHub issue #7103, the interpolation code 3 is the same as
code 2 in the `FlexibleInterpVar` and the `PiecewiseInterpolation`
classes.

According to some comments in the source code, interpolation code 3 was
meant to be "a parabolic version of log-normal".

There were two options to fix this:

1) Actually implement this parabolic interpolation with linear
   extrapolation, analogous to code 2 but in log space.

2) Clearly mark interpolation code 3 as non-existing.

This commit implements solution 2, because the interpolation code 3 is
not mentioned anywhere outside the ROOT source code. Especially not is
the HistFactory paper:
https://cds.cern.ch/record/1456844/files/CERN-OPEN-2012-016.pdf

Implementing a new interpolation scheme that apparently was
not needed in the last 10 years anyway would increase the burden on the
user to understand all these different settings unnecessarily.

Closes #7103.
  • Loading branch information
guitargeek committed Nov 15, 2024
1 parent 5027f07 commit 2aa2426
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ namespace HistFactory{

double evaluate() const override;

private:

void setInterpCodeForParam(int iParam, int code);

ClassDefOverride(RooStats::HistFactory::FlexibleInterpVar,2); // flexible interpolation
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class PiecewiseInterpolation : public RooAbsReal {
void setPositiveDefinite(bool flag=true){_positiveDefinite=flag;}
bool positiveDefinite() const {return _positiveDefinite;}

void setInterpCode(RooAbsReal& param, int code, bool silent=false);
void setInterpCode(RooAbsReal& param, int code, bool silent=true);
void setAllInterpCodes(int code);
void printAllInterpCodes();

Expand Down Expand Up @@ -102,6 +102,10 @@ class PiecewiseInterpolation : public RooAbsReal {
double evaluate() const override;
void doEval(RooFit::EvalContext &) const override;

private:

void setInterpCodeForParam(int iParam, int code);

ClassDefOverride(PiecewiseInterpolation,4) // Sum of RooAbsReal objects
};

Expand Down
77 changes: 41 additions & 36 deletions roofit/histfactory/src/FlexibleInterpVar.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ FlexibleInterpVar::FlexibleInterpVar(const char* name, const char* title,
FlexibleInterpVar::FlexibleInterpVar(const char* name, const char* title,
const RooArgList& paramList,
double argNominal, std::vector<double> const& lowVec, std::vector<double> const& highVec,
std::vector<int> const& code) :
std::vector<int> const& codes) :
RooAbsReal(name, title),
_paramList("paramList","List of paramficients",this),
_nominal(argNominal), _low(lowVec), _high(highVec), _interpCode(code)
_nominal(argNominal), _low(lowVec), _high(highVec)
{
for (auto param : paramList) {
if (!dynamic_cast<RooAbsReal*>(param)) {
Expand All @@ -69,6 +69,11 @@ FlexibleInterpVar::FlexibleInterpVar(const char* name, const char* title,
_paramList.add(*param) ;
}

_interpCode.resize(_paramList.size());
for (std::size_t i = 0; i < codes.size(); ++i) {
setInterpCodeForParam(i, codes[i]);
}

if (_low.size() != _paramList.size() || _low.size() != _high.size() || _low.size() != _interpCode.size()) {
coutE(InputArguments) << "FlexibleInterpVar::ctor(" << GetName() << ") invalid input std::vectors " << std::endl;
R__ASSERT(_low.size() == _paramList.size());
Expand Down Expand Up @@ -109,31 +114,43 @@ FlexibleInterpVar::~FlexibleInterpVar()
TRACE_DESTROY;
}


////////////////////////////////////////////////////////////////////////////////

void FlexibleInterpVar::setInterpCode(RooAbsReal& param, int code){
int index = _paramList.index(&param);
if(index<0){
coutE(InputArguments) << "FlexibleInterpVar::setInterpCode ERROR: " << param.GetName()
<< " is not in list" << std::endl;
} else if(_interpCode.at(index) != code){
coutI(InputArguments) << "FlexibleInterpVar::setInterpCode : " << param.GetName()
<< " is now " << code << std::endl;
_interpCode.at(index) = code;
// GHL: Adding suggestion by Swagato:
setValueDirty();
}
void FlexibleInterpVar::setInterpCode(RooAbsReal &param, int code)
{
int index = _paramList.index(&param);
if (index < 0) {
coutE(InputArguments) << "FlexibleInterpVar::setInterpCode ERROR: " << param.GetName() << " is not in list"
<< std::endl;
return;
}
setInterpCodeForParam(index, code);
}

////////////////////////////////////////////////////////////////////////////////
void FlexibleInterpVar::setAllInterpCodes(int code)
{
for (std::size_t i = 0; i < _interpCode.size(); ++i) {
setInterpCodeForParam(i, code);
}
}

void FlexibleInterpVar::setAllInterpCodes(int code){
for(unsigned int i=0; i<_interpCode.size(); ++i){
_interpCode.at(i) = code;
}
// GHL: Adding suggestion by Swagato:
setValueDirty();
void FlexibleInterpVar::setInterpCodeForParam(int iParam, int code)
{
RooAbsArg const &param = _paramList[iParam];
if (code < 0 || code > 5) {
coutE(InputArguments) << "FlexibleInterpVar::setInterpCode ERROR: " << param.GetName()
<< " with unknown interpolation code " << code << ", keeping current code "
<< _interpCode[iParam] << std::endl;
return;
}
if (code == 3) {
// In the past, code 3 was equivalent to code 2, which confused users.
// Now, we just say that code 3 doesn't exist and default to code 2 in
// that case for backwards compatible behavior.
coutE(InputArguments) << "FlexibleInterpVar::setInterpCode ERROR: " << param.GetName()
<< " with unknown interpolation code " << code << ", defaulting to code 2" << std::endl;
code = 2;
}
_interpCode.at(iParam) = code;
setValueDirty();
}

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -198,10 +215,6 @@ double FlexibleInterpVar::evaluate() const
double total(_nominal);
for (std::size_t i = 0; i < _paramList.size(); ++i) {
int code = _interpCode[i];
if (code < 0 || code > 4) {
coutE(InputArguments) << "FlexibleInterpVar::evaluate ERROR: param " << i
<< " with unknown interpolation code" << std::endl;
}
// To get consistent codes with the PiecewiseInterpolation
if (code == 4) {
code = 5;
Expand All @@ -223,10 +236,6 @@ void FlexibleInterpVar::translate(RooFit::Detail::CodeSquashContext &ctx) const
unsigned int n = _interpCode.size();

int interpCode = _interpCode[0];
if (interpCode < 0 || interpCode > 4) {
coutE(InputArguments) << "FlexibleInterpVar::evaluate ERROR: param " << 0
<< " with unknown interpolation code" << std::endl;
}
// To get consistent codes with the PiecewiseInterpolation
if (interpCode == 4) {
interpCode = 5;
Expand All @@ -251,10 +260,6 @@ void FlexibleInterpVar::doEval(RooFit::EvalContext &ctx) const

for (std::size_t i = 0; i < _paramList.size(); ++i) {
int code = _interpCode[i];
if (code < 0 || code > 4) {
coutE(InputArguments) << "FlexibleInterpVar::evaluate ERROR: param " << i
<< " with unknown interpolation code" << std::endl;
}
// To get consistent codes with the PiecewiseInterpolation
if (code == 4) {
code = 5;
Expand Down
2 changes: 1 addition & 1 deletion roofit/histfactory/src/HistoToWorkspaceFactoryFast.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ RooArgList HistoToWorkspaceFactoryFast::createObservables(const TH1 *hist, RooWo
assert(lowVec.size() == params.size());

FlexibleInterpVar interp( (interpName).c_str(), "", params, 1., lowVec, highVec);
interp.setAllInterpCodes(4); // LM: change to 4 (piece-wise linear to 6th order polynomial interpolation + linear extrapolation )
interp.setAllInterpCodes(4); // LM: change to 4 (piece-wise exponential to 6th order polynomial interpolation + exponential extrapolation )
//interp.setAllInterpCodes(0); // simple linear interpolation
proto.import(interp); // params have already been imported in first loop of this function
} else{
Expand Down
86 changes: 40 additions & 46 deletions roofit/histfactory/src/PiecewiseInterpolation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,8 @@ double PiecewiseInterpolation::evaluate() const
auto param = static_cast<RooAbsReal*>(_paramSet.at(i));
auto low = static_cast<RooAbsReal*>(_lowSet.at(i));
auto high = static_cast<RooAbsReal*>(_highSet.at(i));
Int_t icode = _interpCode[i] ;

if(icode < 0 || icode > 5) {
coutE(InputArguments) << "PiecewiseInterpolation::evaluate ERROR: " << param->GetName()
<< " with unknown interpolation code" << icode << endl ;
}
using RooFit::Detail::MathFuncs::flexibleInterpSingle;
sum += flexibleInterpSingle(icode, low->getVal(), high->getVal(), 1.0, nominal, param->getVal(), sum);
sum += flexibleInterpSingle(_interpCode[i], low->getVal(), high->getVal(), 1.0, nominal, param->getVal(), sum);
}

if(_positiveDefinite && (sum<0)){
Expand All @@ -190,10 +184,6 @@ void PiecewiseInterpolation::translate(RooFit::Detail::CodeSquashContext &ctx) c

std::string resName = "total_" + ctx.getTmpVarName();
for (std::size_t i = 0; i < n; ++i) {
if (_interpCode[i] < 0 || _interpCode[i] > 5) {
coutE(InputArguments) << "PiecewiseInterpolation::evaluate ERROR: " << _paramSet[i].GetName()
<< " with unknown interpolation code" << _interpCode[i] << endl;
}
if (_interpCode[i] != _interpCode[0]) {
coutE(InputArguments) << "FlexibleInterpVar::evaluate ERROR: Code Squashing AD does not yet support having "
"different interpolation codes for the same class object "
Expand Down Expand Up @@ -277,18 +267,10 @@ void PiecewiseInterpolation::doEval(RooFit::EvalContext &ctx) const
auto param = ctx.at(_paramSet.at(i));
auto low = ctx.at(_lowSet.at(i));
auto high = ctx.at(_highSet.at(i));
const int icode = _interpCode[i];

if (icode < 0 || icode > 5) {
coutE(InputArguments) << "PiecewiseInterpolation::doEval(): " << _paramSet[i].GetName()
<< " with unknown interpolation code" << icode << std::endl;
throw std::invalid_argument("PiecewiseInterpolation::doEval() got invalid interpolation code " +
std::to_string(icode));
}

for (std::size_t j = 0; j < sum.size(); ++j) {
using RooFit::Detail::MathFuncs::flexibleInterpSingle;
sum[j] += flexibleInterpSingle(icode, broadcast(low, j), broadcast(high, j), 1.0, broadcast(nominal, j),
sum[j] += flexibleInterpSingle(_interpCode[i], broadcast(low, j), broadcast(high, j), 1.0, broadcast(nominal, j),
broadcast(param, j), sum[j]);
}
}
Expand Down Expand Up @@ -347,7 +329,7 @@ Int_t PiecewiseInterpolation::getAnalyticalIntegralWN(RooArgSet& allVars, RooArg


// KC: check if interCode=0 for all
for (auto it = _paramSet.begin(); it != _paramSet.end(); ++it) {
for (auto it = _paramSet.begin(); it != _paramSet.end(); ++it) {
if (!_interpCode.empty() && _interpCode[it - _paramSet.begin()] != 0) {
// can't factorize integral
cout << "can't factorize integral" << endl;
Expand All @@ -371,7 +353,7 @@ Int_t PiecewiseInterpolation::getAnalyticalIntegralWN(RooArgSet& allVars, RooArg
// Make list of function projection and normalization integrals
RooAbsReal *func ;

// do variations
// do variations
for (auto it = _paramSet.begin(); it != _paramSet.end(); ++it)
{
auto i = it - _paramSet.begin();
Expand Down Expand Up @@ -491,12 +473,12 @@ double PiecewiseInterpolation::analyticalIntegralWN(Int_t code, const RooArgSet*

// now get low/high variations
// KC: old interp code with new iterator

i = 0;
for (auto const *param : static_range_cast<RooAbsReal *>(_paramSet)) {
low = static_cast<RooAbsReal *>(cache->_lowIntList.at(i));
high = static_cast<RooAbsReal *>(cache->_highIntList.at(i));

if(param->getVal() > 0) {
value += param->getVal()*(high->getVal() - nominal);
} else {
Expand Down Expand Up @@ -573,32 +555,44 @@ double PiecewiseInterpolation::analyticalIntegralWN(Int_t code, const RooArgSet*
return value;
}


////////////////////////////////////////////////////////////////////////////////

void PiecewiseInterpolation::setInterpCode(RooAbsReal& param, int code, bool silent){
int index = _paramSet.index(&param);
if(index<0){
coutE(InputArguments) << "PiecewiseInterpolation::setInterpCode ERROR: " << param.GetName()
<< " is not in list" << endl ;
} else {
if(!silent){
coutW(InputArguments) << "PiecewiseInterpolation::setInterpCode : " << param.GetName()
<< " is now " << code << endl ;
}
_interpCode.at(index) = code;
}
void PiecewiseInterpolation::setInterpCode(RooAbsReal &param, int code, bool /*silent*/)
{
int index = _paramSet.index(&param);
if (index < 0) {
coutE(InputArguments) << "PiecewiseInterpolation::setInterpCode ERROR: " << param.GetName() << " is not in list"
<< std::endl;
return;
}
setInterpCodeForParam(index, code);
}


////////////////////////////////////////////////////////////////////////////////

void PiecewiseInterpolation::setAllInterpCodes(int code){
for(unsigned int i=0; i<_interpCode.size(); ++i){
_interpCode.at(i) = code;
}
void PiecewiseInterpolation::setAllInterpCodes(int code)
{
for (std::size_t i = 0; i < _interpCode.size(); ++i) {
setInterpCodeForParam(i, code);
}
}

void PiecewiseInterpolation::setInterpCodeForParam(int iParam, int code)
{
RooAbsArg const &param = _paramSet[iParam];
if (code < 0 || code > 5) {
coutE(InputArguments) << "PiecewiseInterpolation::setInterpCode ERROR: " << param.GetName()
<< " with unknown interpolation code " << code << ", keeping current code "
<< _interpCode[iParam] << std::endl;
return;
}
if (code == 3) {
// In the past, code 3 was equivalent to code 2, which confused users.
// Now, we just say that code 3 doesn't exist and default to code 2 in
// that case for backwards compatible behavior.
coutE(InputArguments) << "PiecewiseInterpolation::setInterpCode ERROR: " << param.GetName()
<< " with unknown interpolation code " << code << ", defaulting to code 2" << std::endl;
code = 2;
}
_interpCode.at(iParam) = code;
setValueDirty();
}

////////////////////////////////////////////////////////////////////////////////

Expand Down
16 changes: 4 additions & 12 deletions roofit/roofitcore/inc/RooFit/Detail/MathFuncs.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,18 +239,10 @@ inline double flexibleInterpSingle(unsigned int code, double low, double high, d
} else {
return a * std::pow(paramVal, 2) + b * paramVal + c;
}
} else if (code == 3) {
// parabolic version of log-normal
double a = 0.5 * (high + low) - nominal;
double b = 0.5 * (high - low);
double c = 0;
if (paramVal > 1) {
return (2 * a + b) * (paramVal - 1) + high - nominal;
} else if (paramVal < -1) {
return -1 * (2 * a - b) * (paramVal + 1) + low - nominal;
} else {
return a * std::pow(paramVal, 2) + b * paramVal + c;
}
// According to an old comment in the source code, code 3 was apparently
// meant to be a "parabolic version of log-normal", but it never got
// implemented. If someone would need it, it could be implemented as doing
// code 2 in log space.
} else if (code == 4) {
double x = paramVal;
if (x >= boundary) {
Expand Down

0 comments on commit 2aa2426

Please sign in to comment.