Skip to content

Commit

Permalink
[RF] Put codegen related functions in RooFit::Experimental
Browse files Browse the repository at this point in the history
  • Loading branch information
guitargeek committed Nov 18, 2024
1 parent cfdeddf commit 2d48ca9
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 54 deletions.
13 changes: 8 additions & 5 deletions roofit/codegen/inc/RooFit/CodegenImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,16 @@ class RooNLLVarNew;
class RooNormalizedPdf;
} // namespace Detail

namespace Experimental {

class CodegenContext;

void codegenImpl(Detail::RooFixedProdPdf &arg, CodegenContext &ctx);
void codegenImpl(Detail::RooNLLVarNew &arg, CodegenContext &ctx);
void codegenImpl(Detail::RooNormalizedPdf &arg, CodegenContext &ctx);
void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx);
void codegenImpl(RooFit::Detail::RooNLLVarNew &arg, CodegenContext &ctx);
void codegenImpl(RooFit::Detail::RooNormalizedPdf &arg, CodegenContext &ctx);
void codegenImpl(ParamHistFunc &arg, CodegenContext &ctx);
void codegenImpl(PiecewiseInterpolation &arg, CodegenContext &ctx);
void codegenImpl(RooAbsArg &arg, RooFit::CodegenContext &ctx);
void codegenImpl(RooAbsArg &arg, CodegenContext &ctx);
void codegenImpl(RooAddPdf &arg, CodegenContext &ctx);
void codegenImpl(RooAddition &arg, CodegenContext &ctx);
void codegenImpl(RooBernstein &arg, CodegenContext &ctx);
Expand Down Expand Up @@ -144,12 +146,13 @@ std::string codegenIntegralImpl(Arg_t &arg, int code, const char *rangeName, Cod
template <class Arg_t>
struct CodegenIntegralImplCaller {

static auto call(RooAbsReal &arg, int code, const char *rangeName, RooFit::CodegenContext &ctx)
static auto call(RooAbsReal &arg, int code, const char *rangeName, CodegenContext &ctx)
{
return codegenIntegralImpl(static_cast<Arg_t &>(arg), code, rangeName, ctx, PrioHighest{});
}
};

} // namespace Experimental
} // namespace RooFit

#endif
14 changes: 8 additions & 6 deletions roofit/codegen/src/CodegenImpl.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include <TInterpreter.h>

namespace RooFit {
namespace Experimental {

namespace {

Expand Down Expand Up @@ -123,7 +124,7 @@ std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, R

} // namespace

void codegenImpl(Detail::RooFixedProdPdf &arg, CodegenContext &ctx)
void codegenImpl(RooFit::Detail::RooFixedProdPdf &arg, CodegenContext &ctx)
{
if (arg.cache()._isRearranged) {
ctx.addResult(&arg, ctx.buildCall(mathFunc("ratio"), *arg.cache()._rearrangedNum, *arg.cache()._rearrangedDen));
Expand Down Expand Up @@ -218,7 +219,7 @@ void codegenImpl(PiecewiseInterpolation &arg, CodegenContext &ctx)
///
/// \param[in] ctx An object to manage auxiliary information for code-squashing. Also takes the
/// code string that this class outputs into the squashed code through the 'addToCodeBody' function.
void codegenImpl(RooAbsArg &arg, RooFit::CodegenContext &ctx)
void codegenImpl(RooAbsArg &arg, CodegenContext &ctx)
{
std::stringstream errorMsg;
errorMsg << "Translate function for class \"" << arg.ClassName() << "\" has not yet been implemented.";
Expand Down Expand Up @@ -251,7 +252,7 @@ void codegenImpl(RooAddition &arg, CodegenContext &ctx)
std::size_t i = 0;
for (auto *component : static_range_cast<RooAbsReal *>(arg.list())) {

if (!dynamic_cast<Detail::RooNLLVarNew *>(component) || arg.list().size() == 1) {
if (!dynamic_cast<RooFit::Detail::RooNLLVarNew *>(component) || arg.list().size() == 1) {
result += ctx.getResult(*component);
++i;
if (i < arg.list().size())
Expand Down Expand Up @@ -396,7 +397,7 @@ void codegenImpl(RooLognormal &arg, CodegenContext &ctx)
ctx.addResult(&arg, ctx.buildCall(mathFunc(funcName), arg.getX(), arg.getShapeK(), arg.getMedian()));
}

void codegenImpl(Detail::RooNLLVarNew &arg, CodegenContext &ctx)
void codegenImpl(RooFit::Detail::RooNLLVarNew &arg, CodegenContext &ctx)
{
if (arg.binnedL() && !arg.pdf().getAttribute("BinnedLikelihoodActiveYields")) {
std::stringstream errorMsg;
Expand Down Expand Up @@ -438,7 +439,7 @@ void codegenImpl(Detail::RooNLLVarNew &arg, CodegenContext &ctx)
}
}

void codegenImpl(Detail::RooNormalizedPdf &arg, CodegenContext &ctx)
void codegenImpl(RooFit::Detail::RooNormalizedPdf &arg, CodegenContext &ctx)
{
// For now just return function/normalization integral.
ctx.addResult(&arg, ctx.getResult(arg.pdf()) + "/" + ctx.getResult(arg.normIntegral()));
Expand Down Expand Up @@ -520,7 +521,7 @@ std::string codegenIntegral(RooAbsReal &arg, int code, const char *rangeName, Co
} else {
// Can probably done with CppInterop in the future to avoid string manipulation.
std::stringstream cmd;
cmd << "&RooFit::CodegenIntegralImplCaller<" << tclass->GetName() << ">::call;";
cmd << "&RooFit::Experimental::CodegenIntegralImplCaller<" << tclass->GetName() << ">::call;";
func = reinterpret_cast<Func>(gInterpreter->ProcessLine(cmd.str().c_str()));
dispatchMap[tclass] = func;
}
Expand Down Expand Up @@ -844,4 +845,5 @@ std::string codegenIntegralImpl(RooUniform &arg, int code, const char *rangeName
return std::to_string(arg.analyticalIntegral(code, rangeName));
}

} // namespace Experimental
} // namespace RooFit
2 changes: 2 additions & 0 deletions roofit/roofitcore/inc/RooAbsArg.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ using RooListProxy = RooCollectionProxy<RooArgList>;
class RooExpensiveObjectCache ;
class RooWorkspace ;
namespace RooFit {
namespace Experimental {
class CodegenContext;
}
}

class RooRefArray : public TObjArray {
public:
Expand Down
4 changes: 2 additions & 2 deletions roofit/roofitcore/inc/RooDataHist.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ class RooDataHist : public RooAbsData, public RooDirItem {
double const* wgtErrHiArray() const { return _errHi; }
double const* sumW2Array() const { return _sumw2; }

std::string calculateTreeIndexForCodeSquash(RooAbsArg const *klass, RooFit::CodegenContext &ctx,
std::string calculateTreeIndexForCodeSquash(RooAbsArg const *klass, RooFit::Experimental::CodegenContext &ctx,
const RooAbsCollection &coords, bool reverse = false) const;
std::string declWeightArrayForCodeSquash(RooFit::CodegenContext &ctx,
std::string declWeightArrayForCodeSquash(RooFit::Experimental::CodegenContext &ctx,
bool correctForBinSize) const;

protected:
Expand Down
23 changes: 4 additions & 19 deletions roofit/roofitcore/inc/RooFit/CodegenContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ template <class T>
class RooTemplateProxy;

namespace RooFit {
namespace Experimental {

template <int P>
struct Prio {
Expand Down Expand Up @@ -226,27 +227,11 @@ std::string CodegenContext::buildArgSpanImpl(std::span<const T> arr)
return arrName;
}

template <class Arg_t, int P>
void codegenImpl(Arg_t &arg, RooFit::CodegenContext &ctx, Prio<P> p)
{
if constexpr (std::is_same<Prio<P>, PrioLowest>::value) {
return codegenImpl(arg, ctx);
} else {
return codegenImpl(arg, ctx, p.next());
}
}

template<class Arg_t>
struct CodegenImplCaller {

static auto call(RooAbsArg &arg, RooFit::CodegenContext &ctx) {
return codegenImpl(static_cast<Arg_t&>(arg), ctx, PrioHighest{});
}

};
void declareDispatcherCode(std::string const &funcName);

void codegen(RooAbsArg &arg, RooFit::CodegenContext &ctx);
void codegen(RooAbsArg &arg, CodegenContext &ctx);

} // namespace Experimental
} // namespace RooFit

#endif
5 changes: 4 additions & 1 deletion roofit/roofitcore/res/RooFitImplHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
#include <RooAbsReal.h>

#include <sstream>
#include <vector>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

class RooAbsPdf;
class RooAbsData;
Expand Down Expand Up @@ -100,6 +101,8 @@ namespace Detail {

std::string makeValidVarName(std::string const &in);

void replaceAll(std::string &inOut, std::string_view what, std::string_view with);

} // namespace Detail
} // namespace RooFit

Expand Down
22 changes: 8 additions & 14 deletions roofit/roofitcore/src/RooClassFactory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ instantiate objects.
#include "RooWorkspace.h"
#include "RooGlobalFunc.h"
#include "RooAbsPdf.h"
#include "RooFitImplHelpers.h"

#include <ROOT/StringUtils.hxx>

Expand Down Expand Up @@ -401,15 +402,6 @@ std::string getFromVarSpans(std::vector<std::string> const &alist)
return ss.str();
}

/// Replace all occurrences of `what` with `with` inside of `inOut`.
void replaceAll(std::string &inOut, std::string_view what, std::string_view with)
{
for (std::string::size_type pos{}; inOut.npos != (pos = inOut.find(what.data(), pos, what.length()));
pos += with.length()) {
inOut.replace(pos, what.length(), with.data(), with.length());
}
}

inline bool isSpecial(char c)
{
return c != '_' && !std::isalnum(c);
Expand Down Expand Up @@ -560,9 +552,11 @@ class CLASS_NAME : public BASE_NAME {
};
namespace RooFit {
namespace Experimental {
void codegenImpl(CLASS_NAME &arg, CodegenContext &ctx);
} // namespace Experimental
} // namespace RooFit
)";
Expand Down Expand Up @@ -687,7 +681,7 @@ CLASS_NAME::CLASS_NAME(const char *name, const char *title,
}
}

cf << "void RooFit::codegenImpl(CLASS_NAME &arg, RooFit::CodegenContext &ctx)\n"
cf << "void RooFit::Experimental::codegenImpl(CLASS_NAME &arg, RooFit::Experimental::CodegenContext &ctx)\n"
<< "{\n"
<< " ctx.addResult(&arg, ctx.buildCall(\"CLASS_NAME_evaluate\", " << varsGetters.str() << "));\n"
<<"}\n";
Expand Down Expand Up @@ -796,10 +790,10 @@ void CLASS_NAME::generateEvent(int code)
std::ofstream ocf(className + ".cxx");
std::string headerCode = hf.str();
std::string sourceCode = cf.str();
replaceAll(headerCode, "CLASS_NAME", className);
replaceAll(sourceCode, "CLASS_NAME", className);
replaceAll(headerCode, "BASE_NAME", baseName);
replaceAll(sourceCode, "BASE_NAME", baseName);
RooFit::Detail::replaceAll(headerCode, "CLASS_NAME", className);
RooFit::Detail::replaceAll(sourceCode, "CLASS_NAME", className);
RooFit::Detail::replaceAll(headerCode, "BASE_NAME", baseName);
RooFit::Detail::replaceAll(sourceCode, "BASE_NAME", baseName);
ohf << headerCode;
ocf << sourceCode;

Expand Down
4 changes: 2 additions & 2 deletions roofit/roofitcore/src/RooDataHist.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ Int_t RooDataHist::getIndex(const RooAbsCollection& coord, bool fast) const {
return calcTreeIndex(coord, fast);
}

std::string RooDataHist::declWeightArrayForCodeSquash(RooFit::CodegenContext &ctx,
std::string RooDataHist::declWeightArrayForCodeSquash(RooFit::Experimental::CodegenContext &ctx,
bool correctForBinSize) const
{
std::vector<double> vals(_arrSize);
Expand All @@ -1008,7 +1008,7 @@ std::string RooDataHist::declWeightArrayForCodeSquash(RooFit::CodegenContext &ct
return ctx.buildArg(vals);
}

std::string RooDataHist::calculateTreeIndexForCodeSquash(RooAbsArg const * /*klass*/, RooFit::CodegenContext &ctx,
std::string RooDataHist::calculateTreeIndexForCodeSquash(RooAbsArg const * /*klass*/, RooFit::Experimental::CodegenContext &ctx,
const RooAbsCollection &coords, bool reverse) const
{
assert(coords.size() == _vars.size());
Expand Down
49 changes: 45 additions & 4 deletions roofit/roofitcore/src/RooFit/CodegenContext.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ bool startsWith(std::string_view str, std::string_view prefix)
} // namespace

namespace RooFit {
namespace Experimental {

/// @brief Adds (or overwrites) the string representing the result of a node.
/// @param key The name of the node to add the result for.
Expand Down Expand Up @@ -73,7 +74,7 @@ std::string const &CodegenContext::getResult(RooAbsArg const &arg)
}

// Now, recursively call translate into the current argument to load the correct result.
RooFit::codegen(const_cast<RooAbsArg &>(arg), *this);
codegen(const_cast<RooAbsArg &>(arg), *this);

return _nodeNames.at(arg.namePtr());
}
Expand Down Expand Up @@ -309,9 +310,48 @@ CodegenContext::buildFunction(RooAbsArg const &arg, std::map<RooFit::Detail::Dat
return funcName;
}

void codegen(RooAbsArg &arg, RooFit::CodegenContext &ctx)
void declareDispatcherCode(std::string const &funcName)
{
using Func = void (*)(RooAbsArg &, RooFit::CodegenContext &);
std::string dispatcherCode = R"(
namespace RooFit {
namespace Experimental {
template <class Arg_t, int P>
auto FUNC_NAME(Arg_t &arg, CodegenContext &ctx, Prio<P> p)
{
if constexpr (std::is_same<Prio<P>, PrioLowest>::value) {
return FUNC_NAME(arg, ctx);
} else {
return FUNC_NAME(arg, ctx, p.next());
}
}
template <class Arg_t>
struct Caller_FUNC_NAME {
static auto call(RooAbsArg &arg, CodegenContext &ctx)
{
return FUNC_NAME(static_cast<Arg_t &>(arg), ctx, PrioHighest{});
}
};
} // namespace Experimental
} // namespace RooFit
)";

RooFit::Detail::replaceAll(dispatcherCode, "FUNC_NAME", funcName);
gInterpreter->Declare(dispatcherCode.c_str());
}

void codegen(RooAbsArg &arg, CodegenContext &ctx)
{
static bool codeDeclared = false;
if (!codeDeclared) {
declareDispatcherCode("codegenImpl");
codeDeclared = true;
}

using Func = void (*)(RooAbsArg &, CodegenContext &);

Func func;

Expand All @@ -327,12 +367,13 @@ void codegen(RooAbsArg &arg, RooFit::CodegenContext &ctx)
} else {
// Can probably done with CppInterop in the future to avoid string manipulation.
std::stringstream cmd;
cmd << "&RooFit::CodegenImplCaller<" << tclass->GetName() << ">::call;";
cmd << "&RooFit::Experimental::Caller_codegenImpl<" << tclass->GetName() << ">::call;";
func = reinterpret_cast<Func>(gInterpreter->ProcessLine(cmd.str().c_str()));
dispatchMap[tclass] = func;
}

return func(arg, ctx);
}

} // namespace Experimental
} // namespace RooFit
9 changes: 9 additions & 0 deletions roofit/roofitcore/src/RooFitImplHelpers.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,15 @@ std::string makeValidVarName(std::string const &in)
return out;
}

/// Replace all occurrences of `what` with `with` inside of `inOut`.
void replaceAll(std::string &inOut, std::string_view what, std::string_view with)
{
for (std::string::size_type pos{}; inOut.npos != (pos = inOut.find(what.data(), pos, what.length()));
pos += with.length()) {
inOut.replace(pos, what.length(), with.data(), with.length());
}
}

} // namespace Detail
} // namespace RooFit

Expand Down
2 changes: 1 addition & 1 deletion roofit/roofitcore/src/RooFuncWrapper.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal &
return found != spans.end() ? found->second.size() : -1;
});

RooFit::CodegenContext ctx;
RooFit::Experimental::CodegenContext ctx;

// First update the result variable of params in the compute graph to in[<position>].
int idx = 0;
Expand Down

0 comments on commit 2d48ca9

Please sign in to comment.