Skip to content

Commit

Permalink
[RF] Avoid referencing RooFuncWrapper inside code generation context
Browse files Browse the repository at this point in the history
It was a bit awkward that the RooFuncWrapper instantiated a code
generation context, which itself had to use the RooFuncWrapper via a
reference.

This commit suggests to refactor the code such that the context doesn't
need to know anything about the RooFuncWrapper.
  • Loading branch information
guitargeek committed Nov 18, 2024
1 parent 9e2df7e commit 29366b0
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 133 deletions.
21 changes: 10 additions & 11 deletions roofit/roofitcore/inc/RooFit/Detail/CodeSquashContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,11 @@ class RooTemplateProxy;

namespace RooFit {

namespace Experimental {
class RooFuncWrapper;
}

namespace Detail {

/// @brief A class to maintain the context for squashing of RooFit models into code.
class CodeSquashContext {
public:
CodeSquashContext(std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes, std::vector<double> &xlarr,
Experimental::RooFuncWrapper &wrapper);

void addResult(RooAbsArg const *key, std::string const &value);
void addResult(const char *key, std::string const &value);

Expand All @@ -68,7 +61,6 @@ class CodeSquashContext {
}

void addToGlobalScope(std::string const &str);
std::string assembleCode(std::string const &returnExpr);
void addVecObs(const char *key, int idx);

void addToCodeBody(RooAbsArg const *klass, std::string const &in);
Expand Down Expand Up @@ -113,9 +105,15 @@ class CodeSquashContext {
std::string buildArg(std::span<const double> arr);
std::string buildArg(std::span<const int> arr) { return buildArgSpanImpl(arr); }

std::vector<double> const &xlArr() { return _xlArr; }

void collectFunction(std::string const &name);
std::vector<std::string> const &collectedFunctions() { return _collectedFunctions; }

std::string
buildFunction(RooAbsArg const &arg, std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes);

Experimental::RooFuncWrapper *_wrapper = nullptr;
auto const &outputSizes() const { return _nodeOutputSizes; }

private:
template <class T>
Expand Down Expand Up @@ -189,8 +187,9 @@ class CodeSquashContext {
/// Mainly used for placing decls outside of loops.
std::string _tempScope;
/// @brief A map to keep track of list names as assigned by addResult.
std::unordered_map<RooFit::UniqueId<RooAbsCollection>::Value_t, std::string> listNames;
std::vector<double> &_xlArr;
std::unordered_map<RooFit::UniqueId<RooAbsCollection>::Value_t, std::string> _listNames;
std::vector<double> _xlArr;
std::vector<std::string> _collectedFunctions;
};

template <>
Expand Down
9 changes: 2 additions & 7 deletions roofit/roofitcore/inc/RooFuncWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,16 @@ class RooFuncWrapper final : public RooAbsReal {

void writeDebugMacro(std::string const &) const;

std::string declareFunction(std::string const &funcBody);
void collectFunction(std::string const &funcName) { _collectedFunctions.emplace_back(funcName); }
std::vector<std::string> const &collectedFunctions() { return _collectedFunctions; }

std::string buildCode(RooAbsReal const &head);

protected:
double evaluate() const override;

private:
void updateGradientVarBuffer() const;

void loadParamsAndData(RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data,
RooSimultaneous const *simPdf);
std::map<RooFit::Detail::DataKey, std::span<const double>>
loadParamsAndData(RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf);

void buildFuncAndGradFunctors();

Expand All @@ -96,7 +92,6 @@ class RooFuncWrapper final : public RooAbsReal {
mutable std::vector<double> _gradientVarBuffer;
std::vector<double> _observables;
std::map<RooFit::Detail::DataKey, ObsInfo> _obsInfos;
std::map<RooFit::Detail::DataKey, std::size_t> _nodeOutputSizes;
std::vector<double> _xlArr;
std::vector<std::string> _collectedFunctions;
};
Expand Down
5 changes: 1 addition & 4 deletions roofit/roofitcore/src/RooAddition.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ in the two sets.
#include "RooNLLVarNew.h"
#include "RooMsgService.h"
#include "RooBatchCompute.h"
#include "RooFuncWrapper.h"

#ifdef ROOFIT_LEGACY_EVAL_BACKEND
#include "RooNLLVar.h"
Expand Down Expand Up @@ -175,9 +174,7 @@ void RooAddition::translate(RooFit::Detail::CodeSquashContext &ctx) const
if (i < _set.size()) result += '+';
continue;
}
auto &wrp = *ctx._wrapper;
auto funcName = wrp.declareFunction(wrp.buildCode(*component));
result += funcName + "(params, obs, xlArr)";
result += ctx.buildFunction(*component, ctx.outputSizes()) + "(params, obs, xlArr)";
++i;
if (i < _set.size()) result += '+';
}
Expand Down
76 changes: 56 additions & 20 deletions roofit/roofitcore/src/RooFit/Detail/CodeSquashContext.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,26 @@

#include <RooFit/Detail/CodeSquashContext.h>

#include "RooFuncWrapper.h"

#include "RooFitImplHelpers.h"

#include <TInterpreter.h>

#include <algorithm>
#include <cctype>

namespace RooFit {
namespace {

namespace Detail {

CodeSquashContext::CodeSquashContext(std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes,
std::vector<double> &xlarr, Experimental::RooFuncWrapper &wrapper)
: _wrapper{&wrapper}, _nodeOutputSizes(outputSizes), _xlArr(xlarr)
bool startsWith(std::string_view str, std::string_view prefix)
{
return str.size() >= prefix.size() && 0 == str.compare(0, prefix.size(), prefix);
}

}

namespace RooFit {

namespace Detail {

/// @brief Adds (or overwrites) the string representing the result of a node.
/// @param key The name of the node to add the result for.
/// @param value The new name to assign/overwrite.
Expand Down Expand Up @@ -83,14 +86,6 @@ void CodeSquashContext::addToGlobalScope(std::string const &str)
_globalScope += str;
}

/// @brief Assemble and return the final code with the return expression and global statements.
/// @param returnExpr The string representation of what the squashed function should return, usually the head node.
/// @return The final body of the function.
std::string CodeSquashContext::assembleCode(std::string const &returnExpr)
{
return _globalScope + _code + "\n return " + returnExpr + ";\n";
}

/// @brief Since the squashed code represents all observables as a single flattened array, it is important
/// to keep track of the start index for a vector valued observable which can later be expanded to access the correct
/// element. For example, a vector valued variable x with 10 entries will be squashed to obs[start_idx + i].
Expand Down Expand Up @@ -227,8 +222,8 @@ std::string CodeSquashContext::buildArg(RooAbsCollection const &in)
return "nullptr";
}

auto it = listNames.find(in.uniqueId().value());
if (it != listNames.end())
auto it = _listNames.find(in.uniqueId().value());
if (it != _listNames.end())
return it->second;

std::string savedName = getTmpVarName();
Expand All @@ -245,7 +240,7 @@ std::string CodeSquashContext::buildArg(RooAbsCollection const &in)

addToCodeBody(declStrm.str(), canSaveOutside);

listNames.insert({in.uniqueId().value(), savedName});
_listNames.insert({in.uniqueId().value(), savedName});
return savedName;
}

Expand All @@ -269,7 +264,48 @@ bool CodeSquashContext::isScopeIndependent(RooAbsArg const *in) const
/// This is useful to dump the standalone C++ code for the computation graph.
void CodeSquashContext::collectFunction(std::string const &name)
{
_wrapper->collectFunction(name);
_collectedFunctions.emplace_back(name);
}

/// @brief Assemble and return the final code with the return expression and global statements.
/// @param returnExpr The string representation of what the squashed function should return, usually the head node.
/// @return The name of the declared function.
std::string CodeSquashContext::buildFunction(RooAbsArg const &arg, std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes)
{
CodeSquashContext ctx;
ctx._nodeOutputSizes = outputSizes;
ctx._vecObsIndices = _vecObsIndices;
// We only want to take over parameters and observables
for (auto const& item : _nodeNames) {
if (startsWith(item.second, "params[") || startsWith(item.second, "obs[")) {
ctx._nodeNames.insert(item);
}
}
ctx._xlArr = _xlArr;
ctx._collectedFunctions = _collectedFunctions;

static int iCodegen = 0;
auto funcName = "roo_codegen_" + std::to_string(iCodegen++);

std::string funcBody = ctx.getResult(arg);
funcBody = ctx._globalScope + ctx._code + "\n return " + funcBody + ";\n";

// Declare the function
std::stringstream bodyWithSigStrm;
bodyWithSigStrm << "double " << funcName << "(double* params, double const* obs, double const* xlArr) {\n"
<< funcBody << "\n}";
ctx._collectedFunctions.emplace_back(funcName);
if (!gInterpreter->Declare(bodyWithSigStrm.str().c_str())) {
std::stringstream errorMsg;
errorMsg << "Function " << funcName << " could not be compiled. See above for details.";
oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
throw std::runtime_error(errorMsg.str().c_str());
}

_xlArr = ctx._xlArr;
_collectedFunctions = ctx._collectedFunctions;

return funcName;
}

} // namespace Detail
Expand Down
96 changes: 36 additions & 60 deletions roofit/roofitcore/src/RooFuncWrapper.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,50 @@ RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal &
_absReal = std::make_unique<RooEvaluatorWrapper>(obj, const_cast<RooAbsData *>(data), false, "", simPdf, false);
}

std::string func;

// Get the parameters.
RooArgSet paramSet;
obj.getParameters(data ? data->get() : nullptr, paramSet);

// Load the parameters and observables.
loadParamsAndData(&obj, paramSet, data, simPdf);
auto spans = loadParamsAndData(paramSet, data, simPdf);

// Set up the code generation context
std::map<RooFit::Detail::DataKey, std::size_t> nodeOutputSizes =
RooFit::BatchModeDataHelpers::determineOutputSizes(obj, [&spans](RooFit::Detail::DataKey key) -> int {
auto found = spans.find(key);
return found != spans.end() ? found->second.size() : -1;
});

RooFit::Detail::CodeSquashContext ctx;

// First update the result variable of params in the compute graph to in[<position>].
int idx = 0;
for (RooAbsArg *param : _params) {
ctx.addResult(param, "params[" + std::to_string(idx) + "]");
idx++;
}

func = buildCode(obj);
for (auto const &item : _obsInfos) {
const char *obsName = item.first->GetName();
// If the observable is scalar, set name to the start idx. else, store
// the start idx and later set the the name to obs[start_idx + curr_idx],
// here curr_idx is defined by a loop producing parent node.
if (item.second.size == 1) {
ctx.addResult(obsName, "obs[" + std::to_string(item.second.idx) + "]");
} else {
ctx.addResult(obsName, "obs");
ctx.addVecObs(obsName, item.second.idx);
}
}

gInterpreter->Declare("#pragma cling optimize(2)");

// Declare the function and create its derivative.
_funcName = declareFunction(func);
_funcName = ctx.buildFunction(obj, nodeOutputSizes);
_func = reinterpret_cast<Func>(gInterpreter->ProcessLine((_funcName + ";").c_str()));

_xlArr = ctx.xlArr();
_collectedFunctions = ctx.collectedFunctions();
}

RooFuncWrapper::RooFuncWrapper(const RooFuncWrapper &other, const char *name)
Expand All @@ -87,8 +115,8 @@ RooFuncWrapper::RooFuncWrapper(const RooFuncWrapper &other, const char *name)
{
}

void RooFuncWrapper::loadParamsAndData(RooAbsArg const *head, RooArgSet const &paramSet, const RooAbsData *data,
RooSimultaneous const *simPdf)
std::map<RooFit::Detail::DataKey, std::span<const double>>
RooFuncWrapper::loadParamsAndData(RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf)
{
// Extract observables
std::stack<std::vector<double>> vectorBuffers; // for data loading
Expand Down Expand Up @@ -124,32 +152,7 @@ void RooFuncWrapper::loadParamsAndData(RooAbsArg const *head, RooArgSet const &p
}
_gradientVarBuffer.resize(_params.size());

if (head) {
_nodeOutputSizes = RooFit::BatchModeDataHelpers::determineOutputSizes(
*head, [&spans](RooFit::Detail::DataKey key) -> int {
auto found = spans.find(key);
return found != spans.end() ? found->second.size() : -1;
});
}
}

std::string RooFuncWrapper::declareFunction(std::string const &funcBody)
{
static int iFuncWrapper = 0;
auto funcName = "roo_func_wrapper_" + std::to_string(iFuncWrapper++);

// Declare the function
std::stringstream bodyWithSigStrm;
bodyWithSigStrm << "double " << funcName << "(double* params, double const* obs, double const* xlArr) {\n"
<< funcBody << "\n}";
_collectedFunctions.emplace_back(funcName);
if (!gInterpreter->Declare(bodyWithSigStrm.str().c_str())) {
std::stringstream errorMsg;
errorMsg << "Function " << funcName << " could not be compiled. See above for details.";
oocoutE(nullptr, InputArguments) << errorMsg.str() << std::endl;
throw std::runtime_error(errorMsg.str().c_str());
}
return funcName;
return spans;
}

void RooFuncWrapper::createGradient()
Expand Down Expand Up @@ -209,33 +212,6 @@ void RooFuncWrapper::gradient(const double *x, double *g) const
_grad(const_cast<double *>(x), _observables.data(), _xlArr.data(), g);
}

std::string RooFuncWrapper::buildCode(RooAbsReal const &head)
{
RooFit::Detail::CodeSquashContext ctx(_nodeOutputSizes, _xlArr, *this);

// First update the result variable of params in the compute graph to in[<position>].
int idx = 0;
for (RooAbsArg *param : _params) {
ctx.addResult(param, "params[" + std::to_string(idx) + "]");
idx++;
}

for (auto const &item : _obsInfos) {
const char *name = item.first->GetName();
// If the observable is scalar, set name to the start idx. else, store
// the start idx and later set the the name to obs[start_idx + curr_idx],
// here curr_idx is defined by a loop producing parent node.
if (item.second.size == 1) {
ctx.addResult(name, "obs[" + std::to_string(item.second.idx) + "]");
} else {
ctx.addResult(name, "obs");
ctx.addVecObs(name, item.second.idx);
}
}

return ctx.assembleCode(ctx.getResult(head));
}

/// @brief Dumps a macro "filename.C" that can be used to test and debug the generated code and gradient.
void RooFuncWrapper::writeDebugMacro(std::string const &filename) const
{
Expand Down
Loading

0 comments on commit 29366b0

Please sign in to comment.