From 2d48ca9ac5ff979611859a227508ec948d5b038b Mon Sep 17 00:00:00 2001 From: Jonas Rembser Date: Mon, 18 Nov 2024 20:27:42 +0100 Subject: [PATCH] [RF] Put codegen related functions in `RooFit::Experimental` --- roofit/codegen/inc/RooFit/CodegenImpl.h | 13 +++-- roofit/codegen/src/CodegenImpl.cxx | 14 +++--- roofit/roofitcore/inc/RooAbsArg.h | 2 + roofit/roofitcore/inc/RooDataHist.h | 4 +- roofit/roofitcore/inc/RooFit/CodegenContext.h | 23 ++------- roofit/roofitcore/res/RooFitImplHelpers.h | 5 +- roofit/roofitcore/src/RooClassFactory.cxx | 22 +++------ roofit/roofitcore/src/RooDataHist.cxx | 4 +- .../roofitcore/src/RooFit/CodegenContext.cxx | 49 +++++++++++++++++-- roofit/roofitcore/src/RooFitImplHelpers.cxx | 9 ++++ roofit/roofitcore/src/RooFuncWrapper.cxx | 2 +- 11 files changed, 93 insertions(+), 54 deletions(-) diff --git a/roofit/codegen/inc/RooFit/CodegenImpl.h b/roofit/codegen/inc/RooFit/CodegenImpl.h index 48a3a6f9906d5..a0d22110a13e1 100644 --- a/roofit/codegen/inc/RooFit/CodegenImpl.h +++ b/roofit/codegen/inc/RooFit/CodegenImpl.h @@ -69,14 +69,16 @@ class RooNLLVarNew; class RooNormalizedPdf; } // namespace Detail +namespace Experimental { + class CodegenContext; -void codegenImpl(Detail::RooFixedProdPdf &arg, CodegenContext &ctx); -void codegenImpl(Detail::RooNLLVarNew &arg, CodegenContext &ctx); -void codegenImpl(Detail::RooNormalizedPdf &arg, CodegenContext &ctx); +void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx); +void codegenImpl(RooFit::Detail::RooNLLVarNew &arg, CodegenContext &ctx); +void codegenImpl(RooFit::Detail::RooNormalizedPdf &arg, CodegenContext &ctx); void codegenImpl(ParamHistFunc &arg, CodegenContext &ctx); void codegenImpl(PiecewiseInterpolation &arg, CodegenContext &ctx); -void codegenImpl(RooAbsArg &arg, RooFit::CodegenContext &ctx); +void codegenImpl(RooAbsArg &arg, CodegenContext &ctx); void codegenImpl(RooAddPdf &arg, CodegenContext &ctx); void codegenImpl(RooAddition &arg, CodegenContext &ctx); void codegenImpl(RooBernstein &arg, CodegenContext &ctx); @@ -144,12 +146,13 @@ std::string codegenIntegralImpl(Arg_t &arg, int code, const char *rangeName, Cod template struct CodegenIntegralImplCaller { - static auto call(RooAbsReal &arg, int code, const char *rangeName, RooFit::CodegenContext &ctx) + static auto call(RooAbsReal &arg, int code, const char *rangeName, CodegenContext &ctx) { return codegenIntegralImpl(static_cast(arg), code, rangeName, ctx, PrioHighest{}); } }; +} // namespace Experimental } // namespace RooFit #endif diff --git a/roofit/codegen/src/CodegenImpl.cxx b/roofit/codegen/src/CodegenImpl.cxx index eec3f065c49d8..a4ef24b67f499 100644 --- a/roofit/codegen/src/CodegenImpl.cxx +++ b/roofit/codegen/src/CodegenImpl.cxx @@ -60,6 +60,7 @@ #include namespace RooFit { +namespace Experimental { namespace { @@ -123,7 +124,7 @@ std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, R } // namespace -void codegenImpl(Detail::RooFixedProdPdf &arg, CodegenContext &ctx) +void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx) { if (arg.cache()._isRearranged) { ctx.addResult(&arg, ctx.buildCall(mathFunc("ratio"), *arg.cache()._rearrangedNum, *arg.cache()._rearrangedDen)); @@ -218,7 +219,7 @@ void codegenImpl(PiecewiseInterpolation &arg, CodegenContext &ctx) /// /// \param[in] ctx An object to manage auxiliary information for code-squashing. Also takes the /// code string that this class outputs into the squashed code through the 'addToCodeBody' function. -void codegenImpl(RooAbsArg &arg, RooFit::CodegenContext &ctx) +void codegenImpl(RooAbsArg &arg, CodegenContext &ctx) { std::stringstream errorMsg; errorMsg << "Translate function for class \"" << arg.ClassName() << "\" has not yet been implemented."; @@ -251,7 +252,7 @@ void codegenImpl(RooAddition &arg, CodegenContext &ctx) std::size_t i = 0; for (auto *component : static_range_cast(arg.list())) { - if (!dynamic_cast(component) || arg.list().size() == 1) { + if (!dynamic_cast(component) || arg.list().size() == 1) { result += ctx.getResult(*component); ++i; if (i < arg.list().size()) @@ -396,7 +397,7 @@ void codegenImpl(RooLognormal &arg, CodegenContext &ctx) ctx.addResult(&arg, ctx.buildCall(mathFunc(funcName), arg.getX(), arg.getShapeK(), arg.getMedian())); } -void codegenImpl(Detail::RooNLLVarNew &arg, CodegenContext &ctx) +void codegenImpl(RooFit::Detail::RooNLLVarNew &arg, CodegenContext &ctx) { if (arg.binnedL() && !arg.pdf().getAttribute("BinnedLikelihoodActiveYields")) { std::stringstream errorMsg; @@ -438,7 +439,7 @@ void codegenImpl(Detail::RooNLLVarNew &arg, CodegenContext &ctx) } } -void codegenImpl(Detail::RooNormalizedPdf &arg, CodegenContext &ctx) +void codegenImpl(RooFit::Detail::RooNormalizedPdf &arg, CodegenContext &ctx) { // For now just return function/normalization integral. ctx.addResult(&arg, ctx.getResult(arg.pdf()) + "/" + ctx.getResult(arg.normIntegral())); @@ -520,7 +521,7 @@ std::string codegenIntegral(RooAbsReal &arg, int code, const char *rangeName, Co } else { // Can probably done with CppInterop in the future to avoid string manipulation. std::stringstream cmd; - cmd << "&RooFit::CodegenIntegralImplCaller<" << tclass->GetName() << ">::call;"; + cmd << "&RooFit::Experimental::CodegenIntegralImplCaller<" << tclass->GetName() << ">::call;"; func = reinterpret_cast(gInterpreter->ProcessLine(cmd.str().c_str())); dispatchMap[tclass] = func; } @@ -844,4 +845,5 @@ std::string codegenIntegralImpl(RooUniform &arg, int code, const char *rangeName return std::to_string(arg.analyticalIntegral(code, rangeName)); } +} // namespace Experimental } // namespace RooFit diff --git a/roofit/roofitcore/inc/RooAbsArg.h b/roofit/roofitcore/inc/RooAbsArg.h index 655ed128dfbcf..c387bdf13b60b 100644 --- a/roofit/roofitcore/inc/RooAbsArg.h +++ b/roofit/roofitcore/inc/RooAbsArg.h @@ -54,8 +54,10 @@ using RooListProxy = RooCollectionProxy; class RooExpensiveObjectCache ; class RooWorkspace ; namespace RooFit { +namespace Experimental { class CodegenContext; } +} class RooRefArray : public TObjArray { public: diff --git a/roofit/roofitcore/inc/RooDataHist.h b/roofit/roofitcore/inc/RooDataHist.h index 54794b9ae36c5..f27ea2c8dc590 100644 --- a/roofit/roofitcore/inc/RooDataHist.h +++ b/roofit/roofitcore/inc/RooDataHist.h @@ -218,9 +218,9 @@ class RooDataHist : public RooAbsData, public RooDirItem { double const* wgtErrHiArray() const { return _errHi; } double const* sumW2Array() const { return _sumw2; } - std::string calculateTreeIndexForCodeSquash(RooAbsArg const *klass, RooFit::CodegenContext &ctx, + std::string calculateTreeIndexForCodeSquash(RooAbsArg const *klass, RooFit::Experimental::CodegenContext &ctx, const RooAbsCollection &coords, bool reverse = false) const; - std::string declWeightArrayForCodeSquash(RooFit::CodegenContext &ctx, + std::string declWeightArrayForCodeSquash(RooFit::Experimental::CodegenContext &ctx, bool correctForBinSize) const; protected: diff --git a/roofit/roofitcore/inc/RooFit/CodegenContext.h b/roofit/roofitcore/inc/RooFit/CodegenContext.h index c65bc0294a449..8c34d36245e4f 100644 --- a/roofit/roofitcore/inc/RooFit/CodegenContext.h +++ b/roofit/roofitcore/inc/RooFit/CodegenContext.h @@ -31,6 +31,7 @@ template class RooTemplateProxy; namespace RooFit { +namespace Experimental { template struct Prio { @@ -226,27 +227,11 @@ std::string CodegenContext::buildArgSpanImpl(std::span arr) return arrName; } -template -void codegenImpl(Arg_t &arg, RooFit::CodegenContext &ctx, Prio

p) -{ - if constexpr (std::is_same, PrioLowest>::value) { - return codegenImpl(arg, ctx); - } else { - return codegenImpl(arg, ctx, p.next()); - } -} - -template -struct CodegenImplCaller { - - static auto call(RooAbsArg &arg, RooFit::CodegenContext &ctx) { - return codegenImpl(static_cast(arg), ctx, PrioHighest{}); - } - -}; +void declareDispatcherCode(std::string const &funcName); -void codegen(RooAbsArg &arg, RooFit::CodegenContext &ctx); +void codegen(RooAbsArg &arg, CodegenContext &ctx); +} // namespace Experimental } // namespace RooFit #endif diff --git a/roofit/roofitcore/res/RooFitImplHelpers.h b/roofit/roofitcore/res/RooFitImplHelpers.h index 97f14035992d1..2dbf612491906 100644 --- a/roofit/roofitcore/res/RooFitImplHelpers.h +++ b/roofit/roofitcore/res/RooFitImplHelpers.h @@ -16,9 +16,10 @@ #include #include -#include #include +#include #include +#include class RooAbsPdf; class RooAbsData; @@ -100,6 +101,8 @@ namespace Detail { std::string makeValidVarName(std::string const &in); +void replaceAll(std::string &inOut, std::string_view what, std::string_view with); + } // namespace Detail } // namespace RooFit diff --git a/roofit/roofitcore/src/RooClassFactory.cxx b/roofit/roofitcore/src/RooClassFactory.cxx index a6fd99cea777e..81ae578d28c52 100644 --- a/roofit/roofitcore/src/RooClassFactory.cxx +++ b/roofit/roofitcore/src/RooClassFactory.cxx @@ -39,6 +39,7 @@ instantiate objects. #include "RooWorkspace.h" #include "RooGlobalFunc.h" #include "RooAbsPdf.h" +#include "RooFitImplHelpers.h" #include @@ -401,15 +402,6 @@ std::string getFromVarSpans(std::vector const &alist) return ss.str(); } -/// Replace all occurrences of `what` with `with` inside of `inOut`. -void replaceAll(std::string &inOut, std::string_view what, std::string_view with) -{ - for (std::string::size_type pos{}; inOut.npos != (pos = inOut.find(what.data(), pos, what.length())); - pos += with.length()) { - inOut.replace(pos, what.length(), with.data(), with.length()); - } -} - inline bool isSpecial(char c) { return c != '_' && !std::isalnum(c); @@ -560,9 +552,11 @@ class CLASS_NAME : public BASE_NAME { }; namespace RooFit { +namespace Experimental { void codegenImpl(CLASS_NAME &arg, CodegenContext &ctx); +} // namespace Experimental } // namespace RooFit )"; @@ -687,7 +681,7 @@ CLASS_NAME::CLASS_NAME(const char *name, const char *title, } } - cf << "void RooFit::codegenImpl(CLASS_NAME &arg, RooFit::CodegenContext &ctx)\n" + cf << "void RooFit::Experimental::codegenImpl(CLASS_NAME &arg, RooFit::Experimental::CodegenContext &ctx)\n" << "{\n" << " ctx.addResult(&arg, ctx.buildCall(\"CLASS_NAME_evaluate\", " << varsGetters.str() << "));\n" <<"}\n"; @@ -796,10 +790,10 @@ void CLASS_NAME::generateEvent(int code) std::ofstream ocf(className + ".cxx"); std::string headerCode = hf.str(); std::string sourceCode = cf.str(); - replaceAll(headerCode, "CLASS_NAME", className); - replaceAll(sourceCode, "CLASS_NAME", className); - replaceAll(headerCode, "BASE_NAME", baseName); - replaceAll(sourceCode, "BASE_NAME", baseName); + RooFit::Detail::replaceAll(headerCode, "CLASS_NAME", className); + RooFit::Detail::replaceAll(sourceCode, "CLASS_NAME", className); + RooFit::Detail::replaceAll(headerCode, "BASE_NAME", baseName); + RooFit::Detail::replaceAll(sourceCode, "BASE_NAME", baseName); ohf << headerCode; ocf << sourceCode; diff --git a/roofit/roofitcore/src/RooDataHist.cxx b/roofit/roofitcore/src/RooDataHist.cxx index b216c2d0c270c..c0559817e86f2 100644 --- a/roofit/roofitcore/src/RooDataHist.cxx +++ b/roofit/roofitcore/src/RooDataHist.cxx @@ -992,7 +992,7 @@ Int_t RooDataHist::getIndex(const RooAbsCollection& coord, bool fast) const { return calcTreeIndex(coord, fast); } -std::string RooDataHist::declWeightArrayForCodeSquash(RooFit::CodegenContext &ctx, +std::string RooDataHist::declWeightArrayForCodeSquash(RooFit::Experimental::CodegenContext &ctx, bool correctForBinSize) const { std::vector vals(_arrSize); @@ -1008,7 +1008,7 @@ std::string RooDataHist::declWeightArrayForCodeSquash(RooFit::CodegenContext &ct return ctx.buildArg(vals); } -std::string RooDataHist::calculateTreeIndexForCodeSquash(RooAbsArg const * /*klass*/, RooFit::CodegenContext &ctx, +std::string RooDataHist::calculateTreeIndexForCodeSquash(RooAbsArg const * /*klass*/, RooFit::Experimental::CodegenContext &ctx, const RooAbsCollection &coords, bool reverse) const { assert(coords.size() == _vars.size()); diff --git a/roofit/roofitcore/src/RooFit/CodegenContext.cxx b/roofit/roofitcore/src/RooFit/CodegenContext.cxx index 44238a5b0ff2e..a993fd423f05e 100644 --- a/roofit/roofitcore/src/RooFit/CodegenContext.cxx +++ b/roofit/roofitcore/src/RooFit/CodegenContext.cxx @@ -32,6 +32,7 @@ bool startsWith(std::string_view str, std::string_view prefix) } // namespace namespace RooFit { +namespace Experimental { /// @brief Adds (or overwrites) the string representing the result of a node. /// @param key The name of the node to add the result for. @@ -73,7 +74,7 @@ std::string const &CodegenContext::getResult(RooAbsArg const &arg) } // Now, recursively call translate into the current argument to load the correct result. - RooFit::codegen(const_cast(arg), *this); + codegen(const_cast(arg), *this); return _nodeNames.at(arg.namePtr()); } @@ -309,9 +310,48 @@ CodegenContext::buildFunction(RooAbsArg const &arg, std::map +auto FUNC_NAME(Arg_t &arg, CodegenContext &ctx, Prio

p) +{ + if constexpr (std::is_same, PrioLowest>::value) { + return FUNC_NAME(arg, ctx); + } else { + return FUNC_NAME(arg, ctx, p.next()); + } +} + +template +struct Caller_FUNC_NAME { + + static auto call(RooAbsArg &arg, CodegenContext &ctx) + { + return FUNC_NAME(static_cast(arg), ctx, PrioHighest{}); + } +}; + +} // namespace Experimental +} // namespace RooFit + )"; + + RooFit::Detail::replaceAll(dispatcherCode, "FUNC_NAME", funcName); + gInterpreter->Declare(dispatcherCode.c_str()); +} + +void codegen(RooAbsArg &arg, CodegenContext &ctx) +{ + static bool codeDeclared = false; + if (!codeDeclared) { + declareDispatcherCode("codegenImpl"); + codeDeclared = true; + } + + using Func = void (*)(RooAbsArg &, CodegenContext &); Func func; @@ -327,7 +367,7 @@ void codegen(RooAbsArg &arg, RooFit::CodegenContext &ctx) } else { // Can probably done with CppInterop in the future to avoid string manipulation. std::stringstream cmd; - cmd << "&RooFit::CodegenImplCaller<" << tclass->GetName() << ">::call;"; + cmd << "&RooFit::Experimental::Caller_codegenImpl<" << tclass->GetName() << ">::call;"; func = reinterpret_cast(gInterpreter->ProcessLine(cmd.str().c_str())); dispatchMap[tclass] = func; } @@ -335,4 +375,5 @@ void codegen(RooAbsArg &arg, RooFit::CodegenContext &ctx) return func(arg, ctx); } +} // namespace Experimental } // namespace RooFit diff --git a/roofit/roofitcore/src/RooFitImplHelpers.cxx b/roofit/roofitcore/src/RooFitImplHelpers.cxx index 19a24a4bf1252..80c64edd079ac 100644 --- a/roofit/roofitcore/src/RooFitImplHelpers.cxx +++ b/roofit/roofitcore/src/RooFitImplHelpers.cxx @@ -306,6 +306,15 @@ std::string makeValidVarName(std::string const &in) return out; } +/// Replace all occurrences of `what` with `with` inside of `inOut`. +void replaceAll(std::string &inOut, std::string_view what, std::string_view with) +{ + for (std::string::size_type pos{}; inOut.npos != (pos = inOut.find(what.data(), pos, what.length())); + pos += with.length()) { + inOut.replace(pos, what.length(), with.data(), with.length()); + } +} + } // namespace Detail } // namespace RooFit diff --git a/roofit/roofitcore/src/RooFuncWrapper.cxx b/roofit/roofitcore/src/RooFuncWrapper.cxx index 12f9163435739..72a554b59d6d1 100644 --- a/roofit/roofitcore/src/RooFuncWrapper.cxx +++ b/roofit/roofitcore/src/RooFuncWrapper.cxx @@ -71,7 +71,7 @@ RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal & return found != spans.end() ? found->second.size() : -1; }); - RooFit::CodegenContext ctx; + RooFit::Experimental::CodegenContext ctx; // First update the result variable of params in the compute graph to in[]. int idx = 0;