Skip to content

Commit

Permalink
[RF] Faster Hesse in RooFit by advertising which params are independent
Browse files Browse the repository at this point in the history
This reduces the time to run Hesse in the ATLAS Higgs benchmark from
123 seconds to 92 seconds.

Given that some models take hours for this, this is a significant
improvement for the user experience.

Further improvement is possible by analyzing the computation graph a bit
more to find more independent parameters (e.g., the different gammas for
stat uncertainties from different bins).
  • Loading branch information
guitargeek committed Sep 9, 2024
1 parent 5e4ba5d commit eae7b17
Show file tree
Hide file tree
Showing 20 changed files with 171 additions and 12 deletions.
10 changes: 10 additions & 0 deletions math/mathcore/inc/Math/Functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ class Functor : public IBaseFunctionMultiDim {
// for multi-dimensional functions
unsigned int NDim() const override { return fDim; }

/// Report whether the second derivative with respect to parameters i and j
/// is known to be zero. The answer comes from the callback installed via
/// SetVanishingSecondDerivativeFunc(); if no callback is set, nothing is
/// assumed to vanish and false is returned.
bool VanishingSecondDerivative(int i, int j) const override
{
   if (!fVanishSecondDerivFunc) {
      return false;
   }
   return fVanishSecondDerivFunc(i, j);
}

/// Install the callback that VanishingSecondDerivative() consults.
/// \param func Callable taking the parameter indices (i, j); it is moved
///             into this functor, replacing any previously set callback.
void SetVanishingSecondDerivativeFunc(std::function<bool(int, int)> func)
{
   fVanishSecondDerivFunc = std::move(func);
}

private :

inline double DoEval (const double * x) const override {
Expand All @@ -75,6 +79,7 @@ private :

unsigned int fDim;
std::function<double(double const *)> fFunc;
std::function<bool(int, int)> fVanishSecondDerivFunc;
};

/**
Expand Down Expand Up @@ -222,6 +227,10 @@ class GradFunctor : public IGradientFunctionMultiDim {
fGradFunc(x, g);
}

/// Report whether the second derivative with respect to parameters i and j
/// is known to be zero. The answer comes from the callback installed via
/// SetVanishingSecondDerivativeFunc(); if no callback is set, nothing is
/// assumed to vanish and false is returned.
bool VanishingSecondDerivative(int i, int j) const override
{
   if (!fVanishSecondDerivFunc) {
      return false;
   }
   return fVanishSecondDerivFunc(i, j);
}

/// Install the callback that VanishingSecondDerivative() consults.
/// \param func Callable taking the parameter indices (i, j); it is moved
///             into this functor, replacing any previously set callback.
void SetVanishingSecondDerivativeFunc(std::function<bool(int, int)> func)
{
   fVanishSecondDerivFunc = std::move(func);
}

private :

inline double DoEval (const double * x) const override {
Expand All @@ -244,6 +253,7 @@ private :
std::function<double(const double *)> fFunc;
std::function<double(double const *, unsigned int)> fDerivFunc;
std::function<void(const double *, double*)> fGradFunc;
std::function<bool(int, int)> fVanishSecondDerivFunc;
};


Expand Down
5 changes: 5 additions & 0 deletions math/mathcore/inc/Math/IFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ namespace ROOT {
// if it inherits from ROOT::Math::IGradientFunctionMultiDim.
virtual bool HasGradient() const { return false; }

/// Tell whether the second-order derivative with respect to parameters i
/// and j is always zero. Hessian evaluations can use this to avoid
/// expensive function calls. The default implementation makes no such
/// promise and returns false for every pair.
virtual bool VanishingSecondDerivative(int /*i*/, int /*j*/) const
{
   return false;
}

private:

/// Implementation of the evaluation function. Must be implemented by derived classes.
Expand Down
2 changes: 2 additions & 0 deletions math/minuit2/inc/Minuit2/FCNAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class FCNAdapter : public FCNBase {
// forward interface
// virtual double operator()(int npar, double* params,int iflag = 4) const;

bool VanishingSecondDerivative(int i, int j) const { return fFunc.VanishingSecondDerivative(i, j); }

Check warning on line 53 in math/minuit2/inc/Minuit2/FCNAdapter.h

View workflow job for this annotation

GitHub Actions / mac14 ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

'VanishingSecondDerivative' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 53 in math/minuit2/inc/Minuit2/FCNAdapter.h

View workflow job for this annotation

GitHub Actions / mac13 ARM64 LLVM_ENABLE_ASSERTIONS=On, builtin_zlib=ON

'VanishingSecondDerivative' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 53 in math/minuit2/inc/Minuit2/FCNAdapter.h

View workflow job for this annotation

GitHub Actions / alma9-clang clang LLVM_ENABLE_ASSERTIONS=On, CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'VanishingSecondDerivative' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

private:
const Function &fFunc;
double fUp;
Expand Down
5 changes: 5 additions & 0 deletions math/minuit2/inc/Minuit2/FCNBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ class FCNBase : public GenericFunction {
virtual bool HasHessian() const { return false; }

virtual bool HasG2() const { return false; }

/// Tell whether the second-order derivative with respect to parameters i
/// and j is always zero. Hessian evaluations can use this to avoid
/// expensive function calls. The default implementation makes no such
/// promise and returns false for every pair.
virtual bool VanishingSecondDerivative(int /*i*/, int /*j*/) const
{
   return false;
}
};

} // namespace Minuit2
Expand Down
2 changes: 2 additions & 0 deletions math/minuit2/inc/Minuit2/FCNGradAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ class FCNGradAdapter : public FCNBase {

void SetErrorDef(double up) override { fUp = up; }

bool VanishingSecondDerivative(int i, int j) const { return fFunc.VanishingSecondDerivative(i, j); }

Check warning on line 122 in math/minuit2/inc/Minuit2/FCNGradAdapter.h

View workflow job for this annotation

GitHub Actions / mac14 ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

'VanishingSecondDerivative' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 122 in math/minuit2/inc/Minuit2/FCNGradAdapter.h

View workflow job for this annotation

GitHub Actions / mac13 ARM64 LLVM_ENABLE_ASSERTIONS=On, builtin_zlib=ON

'VanishingSecondDerivative' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

Check warning on line 122 in math/minuit2/inc/Minuit2/FCNGradAdapter.h

View workflow job for this annotation

GitHub Actions / alma9-clang clang LLVM_ENABLE_ASSERTIONS=On, CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'VanishingSecondDerivative' overrides a member function but is not marked 'override' [-Winconsistent-missing-override]

private:
const Function &fFunc;
double fUp;
Expand Down
11 changes: 7 additions & 4 deletions math/minuit2/src/MnHesse.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -352,14 +352,17 @@ MinimumState MnHesse::ComputeNumerical(const MnFcn &mfcn, const MinimumState &st
if ((i + 1) == j || in == startParIndexOffDiagonal)
x(i) += dirin(i);

x(j) += dirin(j);

double fs1 = mfcn(x);
if(!doCentralFD) {
if(mfcn.Fcn().VanishingSecondDerivative(i, j)) {
vhmat(i, j) = 0.;
} else if(!doCentralFD) {
x(j) += dirin(j);
double fs1 = mfcn(x);
double elem = (fs1 + amin - yy(i) - yy(j)) / (dirin(i) * dirin(j));
vhmat(i, j) = elem;
x(j) -= dirin(j);
} else {
x(j) += dirin(j);
double fs1 = mfcn(x);
// three more function evaluations required for central fd
x(i) -= dirin(i); x(i) -= dirin(i);double fs3 = mfcn(x);
x(j) -= dirin(j); x(j) -= dirin(j);double fs4 = mfcn(x);
Expand Down
3 changes: 2 additions & 1 deletion roofit/roofitcore/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFitCore
RooFit/TestStatistics/RooSubsidiaryL.h
RooFit/TestStatistics/RooSumL.h
RooFit/TestStatistics/RooUnbinnedL.h
RooFit/TestStatistics/buildLikelihood.h
RooFit/TestStatistics/SharedOffset.h
RooFit/TestStatistics/buildLikelihood.h
RooFit/VariableGroups.h
RooFitLegacy/RooCatTypeLegacy.h
RooFitLegacy/RooCategorySharedProperties.h
RooFitLegacy/RooTreeData.h
Expand Down
4 changes: 3 additions & 1 deletion roofit/roofitcore/inc/RooAbsArg.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ using RooListProxy = RooCollectionProxy<RooArgList>;
class RooExpensiveObjectCache ;
class RooWorkspace ;
namespace RooFit {

namespace Detail {
class CodeSquashContext;
}
// Forward declaration uses `struct`, matching the class-key of the definition
// in RooFit/VariableGroups.h. A mismatched class-key triggers
// -Wmismatched-tags and may cause linker errors under the Microsoft C++ ABI.
struct VariableGroups;
}

class RooRefArray : public TObjArray {
Expand Down Expand Up @@ -270,7 +272,7 @@ class RooAbsArg : public TNamed, public RooPrintable {
bool recursiveCheckObservables(const RooArgSet* nset) const ;
RooFit::OwningPtr<RooArgSet> getComponents() const ;


/// Record in `out` which group(s) the leaf nodes below this object belong to.
/// The base implementation treats this object as a single group: every leaf
/// of its computation graph gets the current group index appended to its
/// entry in `out.groups`, and `out.currentIndex` is then incremented.
/// Derived classes may override this to register their components separately.
virtual void fillVariableGroups(RooFit::VariableGroups &out) const;

void attachArgs(const RooAbsCollection &set);
void attachDataSet(const RooAbsData &set);
Expand Down
2 changes: 2 additions & 0 deletions roofit/roofitcore/inc/RooAddition.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class RooAddition : public RooAbsReal {

void translate(RooFit::Detail::CodeSquashContext &ctx) const override;

void fillVariableGroups(RooFit::VariableGroups &out) const override;

protected:

RooArgList _ownedList ; ///< List of owned components
Expand Down
3 changes: 3 additions & 0 deletions roofit/roofitcore/inc/RooConstraintSum.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class RooConstraintSum : public RooAbsReal {
std::unique_ptr<RooAbsArg> compileForNormSet(RooArgSet const &normSet, RooFit::Detail::CompileContext & ctx) const override;

void translate(RooFit::Detail::CodeSquashContext &ctx) const override;

void fillVariableGroups(RooFit::VariableGroups &out) const override;

protected:

RooListProxy _set1 ; ///< Set of constraint terms
Expand Down
33 changes: 33 additions & 0 deletions roofit/roofitcore/inc/RooFit/VariableGroups.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef RooFit_VariableGroups_h
#define RooFit_VariableGroups_h

#include <TNamed.h>

#include <iostream>
#include <unordered_map>
#include <vector>

class TNamed;

namespace RooFit {

struct VariableGroups {

Check warning on line 14 in roofit/roofitcore/inc/RooFit/VariableGroups.h

View workflow job for this annotation

GitHub Actions / mac14 ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

'VariableGroups' defined as a struct here but previously declared as a class; this is valid, but may result in linker errors under the Microsoft C++ ABI [-Wmismatched-tags]

Check warning on line 14 in roofit/roofitcore/inc/RooFit/VariableGroups.h

View workflow job for this annotation

GitHub Actions / mac14 ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

'VariableGroups' defined as a struct here but previously declared as a class; this is valid, but may result in linker errors under the Microsoft C++ ABI [-Wmismatched-tags]

Check warning on line 14 in roofit/roofitcore/inc/RooFit/VariableGroups.h

View workflow job for this annotation

GitHub Actions / mac14 ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

'VariableGroups' defined as a struct here but previously declared as a class; this is valid, but may result in linker errors under the Microsoft C++ ABI [-Wmismatched-tags]

Check warning on line 14 in roofit/roofitcore/inc/RooFit/VariableGroups.h

View workflow job for this annotation

GitHub Actions / alma9-clang clang LLVM_ENABLE_ASSERTIONS=On, CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'VariableGroups' defined as a struct here but previously declared as a class; this is valid, but may result in linker errors under the Microsoft C++ ABI [-Wmismatched-tags]

Check warning on line 14 in roofit/roofitcore/inc/RooFit/VariableGroups.h

View workflow job for this annotation

GitHub Actions / alma9-clang clang LLVM_ENABLE_ASSERTIONS=On, CMAKE_C_COMPILER=clang, CMAKE_CXX_COMPILER=clang++

'VariableGroups' defined as a struct here but previously declared as a class; this is valid, but may result in linker errors under the Microsoft C++ ABI [-Wmismatched-tags]

std::unordered_map<TNamed const*, std::vector<int>> groups;

inline void print() {
for (auto const& item : groups) {
std::cout << item.first->GetName() << " :";
for (int n : item.second) {
std::cout << " " << n;
}
std::cout << std::endl;
}
}

int currentIndex = 0;
};

}

#endif
20 changes: 19 additions & 1 deletion roofit/roofitcore/src/RooAbsArg.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ for single nodes.
#include <RooArgSet.h>
#include <RooConstVar.h>
#include <RooExpensiveObjectCache.h>
#include <RooFit/VariableGroups.h>
#include <RooHelpers.h>
#include "RooFitImplHelpers.h"
#include <RooListProxy.h>
#include <RooMsgService.h>
#include <RooRealIntegral.h>
Expand All @@ -87,6 +87,8 @@ for single nodes.
#include <RooVectorDataStore.h>
#include <RooWorkspace.h>

#include "RooFitImplHelpers.h"

#include <TBuffer.h>
#include <TClass.h>
#include <TVirtualStreamerInfo.h>
Expand Down Expand Up @@ -2566,3 +2568,19 @@ void RooAbsArg::setDataToken(std::size_t index)
}
_dataToken = index;
}

/// Default implementation: register all leaf nodes below this object as
/// members of one common group, then advance the group counter.
void RooAbsArg::fillVariableGroups(RooFit::VariableGroups &out) const
{
   // Gather the leaves of the computation graph in a RooArgList first and
   // copy them into the RooArgSet in one go afterwards. This sidesteps the
   // deduplication that would otherwise run after each individual insertion.
   RooArgList allLeaves;
   treeNodeServerList(&allLeaves, nullptr, /*branches*/ false, /*leaves*/ true, /*valueOnly*/ false,
                      /*recurseFundamental*/ true);
   RooArgSet uniqueLeaves;
   uniqueLeaves.add(allLeaves.begin(), allLeaves.end());

   // Every leaf of this node is tagged with the same group index.
   for (RooAbsArg const *leaf : uniqueLeaves) {
      out.groups[leaf->namePtr()].push_back(out.currentIndex);
   }

   // The next caller gets a fresh group index.
   ++out.currentIndex;
}
7 changes: 7 additions & 0 deletions roofit/roofitcore/src/RooAddition.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,10 @@ std::list<double>* RooAddition::plotSamplingHint(RooAbsRealLValue& obs, double x
{
return RooRealSumPdf::plotSamplingHint(_set, obs, xlo, xhi);
}

/// Let every element of `_set` register its own variable group(s) in `out`,
/// instead of treating the whole addition as a single group.
void RooAddition::fillVariableGroups(RooFit::VariableGroups &out) const
{
   for (auto *term : _set) {
      term->fillVariableGroups(out);
   }
}
7 changes: 7 additions & 0 deletions roofit/roofitcore/src/RooConstraintSum.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,10 @@ bool RooConstraintSum::setData(RooAbsData const& data, bool /*cloneData=true*/)
}
return true;
}

/// Let every constraint term in `_set1` register its own variable group(s)
/// in `out`, instead of treating the whole sum as a single group.
void RooConstraintSum::fillVariableGroups(RooFit::VariableGroups &out) const
{
   for (auto *constraint : _set1) {
      constraint->fillVariableGroups(out);
   }
}
5 changes: 5 additions & 0 deletions roofit/roofitcore/src/RooEvaluatorWrapper.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,9 @@ bool RooEvaluatorWrapper::setData(RooAbsData &data, bool /*cloneData*/)
return true;
}

/// Delegate the variable-group bookkeeping to the wrapped top node, so the
/// wrapper reports the same groups as the function it evaluates.
void RooEvaluatorWrapper::fillVariableGroups(RooFit::VariableGroups &out) const
{
_topNode->fillVariableGroups(out);
}

/// \endcond
2 changes: 2 additions & 0 deletions roofit/roofitcore/src/RooEvaluatorWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class RooEvaluatorWrapper final : public RooAbsReal {
/// The RooFit::Evaluator is dealing with constant terms itself.
void constOptimizeTestStatistic(ConstOpCode /*opcode*/, bool /*doAlsoTrackingOpt*/) override {}

void fillVariableGroups(RooFit::VariableGroups &out) const override;

protected:
double evaluate() const override;

Expand Down
2 changes: 1 addition & 1 deletion roofit/roofitcore/src/RooMinimizer.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ RooMinimizer::RooMinimizer(RooAbsReal &function, Config const &cfg) : _cfg(cfg)
_fcn = std::make_unique<RooMinimizerFcn>(&function, this);
}
initMinimizerFcnDependentPart(function.defaultErrorLevel());
};
}

/// Initialize the part of the minimizer that is independent of the function to be minimized
void RooMinimizer::initMinimizerFirstPart()
Expand Down
54 changes: 50 additions & 4 deletions roofit/roofitcore/src/RooMinimizerFcn.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
#include "RooAbsArg.h"
#include "RooAbsPdf.h"
#include "RooArgSet.h"
#include "RooRealVar.h"
#include "RooMsgService.h"
#include "RooFit/VariableGroups.h"
#include "RooMinimizer.h"
#include "RooMsgService.h"
#include "RooNaNPacker.h"
#include "RooRealVar.h"

#include "Math/Functor.h"
#include "TMatrixDSym.h"
Expand All @@ -38,6 +39,23 @@ using std::cout, std::endl, std::setprecision;

namespace {

/// Check whether two ranges have at least one element in common.
///
/// Both ranges are walked in lock-step, always advancing the one whose
/// current element compares smaller; this only finds all common elements if
/// both ranges are sorted in ascending order with respect to operator<.
/// \return true as soon as an element is found that is neither smaller nor
///         greater than the current element of the other range, false if
///         either range is exhausted first.
template <class InputIt1, class InputIt2>
bool intersect(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2)
{
   for (; first1 != last1 && first2 != last2;) {
      if (*first1 < *first2) {
         ++first1;
      } else if (*first2 < *first1) {
         ++first2;
      } else {
         // Neither element is smaller: the ranges share this value.
         return true;
      }
   }
   return false;
}

// Helper function that wraps RooAbsArg::getParameters and directly returns the
// output RooArgSet. To be used in the initializer list of the RooMinimizerFcn
// constructor.
Expand All @@ -54,11 +72,34 @@ RooArgSet getParameters(RooAbsReal const &funct)
/// Construct the minimizer function for `funct`.
///
/// Besides setting up the functor handed to the minimizer, this builds a
/// mask telling for which pairs of floating parameters the second derivative
/// may be non-zero. The functor exposes that information through
/// VanishingSecondDerivative().
///
/// \param funct   Function to minimize; also queried for its variable groups.
/// \param context Owning RooMinimizer, whose configuration decides whether
///                the gradient-aware functor is used.
RooMinimizerFcn::RooMinimizerFcn(RooAbsReal *funct, RooMinimizer *context)
   : RooAbsMinimizerFcn(getParameters(*funct), context), _funct(funct)
{
   // Ask the function which group(s) each of its variables belongs to.
   RooFit::VariableGroups groups;
   funct->fillVariableGroups(groups);

   RooArgList const &parameters = *GetFloatParamList();

   std::size_t nParams = parameters.size();

   // Row-major nParams x nParams mask. resize() zero-initializes all
   // entries; an entry is set non-zero if the two parameters share at least
   // one group, i.e. their mixed second derivative cannot be ruled out.
   _secondDerivMask.resize(nParams * nParams);
   for (std::size_t i = 0; i < nParams; ++i) {
      // The diagonal is always evaluated.
      _secondDerivMask[nParams * i + i] = 1;
      for (std::size_t j = 0; j < i; ++j) {
         // NOTE(review): intersect() relies on both index lists being sorted
         // in ascending order; this presumably holds because group indices
         // are appended as the group counter increases - confirm for all
         // fillVariableGroups() overrides.
         auto const &gr1 = groups.groups.at(parameters[i].namePtr());
         auto const &gr2 = groups.groups.at(parameters[j].namePtr());
         _secondDerivMask[nParams * i + j] = intersect(gr1.begin(), gr1.end(), gr2.begin(), gr2.end());
         // The mask is symmetric.
         _secondDerivMask[nParams * j + i] = _secondDerivMask[nParams * i + j];
      }
   }

   // Callback shared by both functor flavours. Capturing `this` is fine
   // because the functor is stored in a member of this object.
   auto vanishing = [this](int i, int j) { return this->vanishingSecondDerivative(i, j); };

   if (context->_cfg.useGradient && funct->hasGradient()) {
      auto functor = std::make_unique<ROOT::Math::GradFunctor>(this, &RooMinimizerFcn::operator(),
                                                               &RooMinimizerFcn::evaluateGradient, getNDim());
      functor->SetVanishingSecondDerivativeFunc(vanishing);
      _multiGenFcn = std::move(functor);
   } else {
      auto functor = std::make_unique<ROOT::Math::Functor>(std::cref(*this), getNDim());
      functor->SetVanishingSecondDerivativeFunc(vanishing);
      _multiGenFcn = std::move(functor);
   }
}

Expand Down Expand Up @@ -132,3 +173,8 @@ void RooMinimizerFcn::setOffsetting(bool flag)
{
_funct->enableOffsetting(flag);
}

/// Look up in the precomputed mask whether the second derivative with
/// respect to parameters i and j is known to vanish.
/// \return true if the mask entry for (i, j) is zero.
bool RooMinimizerFcn::vanishingSecondDerivative(int i, int j) const
{
   // NOTE(review): the mask is laid out using the number of floating
   // parameters at construction time; this indexing assumes _nDim still
   // equals that number - confirm.
   const auto maskEntry = _secondDerivMask[_nDim * i + j];
   return maskEntry == 0;
}
3 changes: 3 additions & 0 deletions roofit/roofitcore/src/RooMinimizerFcn.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@ class RooMinimizerFcn : public RooAbsMinimizerFcn {
double operator()(const double *x) const;
void evaluateGradient(const double *x, double *out) const;

bool vanishingSecondDerivative(int i, int j) const;

private:
RooAbsReal *_funct;
std::unique_ptr<ROOT::Math::IBaseFunctionMultiDim> _multiGenFcn;
std::vector<int> _secondDerivMask;
};

#endif
3 changes: 3 additions & 0 deletions roofit/roofitcore/src/TestStatistics/MinuitFcnGrad.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class MinuitGradFunctor : public ROOT::Math::IMultiGradFunction {

bool returnsInMinuit2ParameterSpace() const override { return _fcn.returnsInMinuit2ParameterSpace(); }

// TODO: Implement this. Until then, report that no second derivative is
// known to vanish, so every Hessian element is evaluated.
// The explicit `return` matters: flowing off the end of a non-void function
// is undefined behaviour.
bool VanishingSecondDerivative(int /*i*/, int /*j*/) const override { return false; }

Check warning on line 47 in roofit/roofitcore/src/TestStatistics/MinuitFcnGrad.cxx

View workflow job for this annotation

GitHub Actions / fedora39 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

statement has no effect [-Wunused-value]

Check warning on line 47 in roofit/roofitcore/src/TestStatistics/MinuitFcnGrad.cxx

View workflow job for this annotation

GitHub Actions / fedora39 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

no return statement in function returning non-void [-Wreturn-type]

Check warning on line 47 in roofit/roofitcore/src/TestStatistics/MinuitFcnGrad.cxx

View workflow job for this annotation

GitHub Actions / alma9 LLVM_ENABLE_ASSERTIONS=On, CMAKE_BUILD_TYPE=Debug

statement has no effect [-Wunused-value]

Check warning on line 47 in roofit/roofitcore/src/TestStatistics/MinuitFcnGrad.cxx

View workflow job for this annotation

GitHub Actions / alma9 LLVM_ENABLE_ASSERTIONS=On, CMAKE_BUILD_TYPE=Debug

no return statement in function returning non-void [-Wreturn-type]

Check warning on line 47 in roofit/roofitcore/src/TestStatistics/MinuitFcnGrad.cxx

View workflow job for this annotation

GitHub Actions / mac14 ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

expression result unused [-Wunused-value]

Check warning on line 47 in roofit/roofitcore/src/TestStatistics/MinuitFcnGrad.cxx

View workflow job for this annotation

GitHub Actions / mac14 ARM64 LLVM_ENABLE_ASSERTIONS=On, CMAKE_CXX_STANDARD=20

non-void function does not return a value [-Wreturn-type]

private:
double DoEval(const double *x) const override { return _fcn(x); }

Expand Down

0 comments on commit eae7b17

Please sign in to comment.