Skip to content

Commit

Permalink
[df] Integrate Report action with systematic variations
Browse files Browse the repository at this point in the history
This commit introduces the behaviour of the Report action in case
systematic variations were requested, i.e. it is now possible to call
VariationsFor(report). In order to create a different report per
variation, the proper branch of the computation graph with the filters
of the "varied universe" must be retrieved. This requires knowing
which variation is being requested when cloning the ReportHelper
instance. To this end, the signature of the MakeNew protocol used by the
action helpers is extended to optionally take the name of the variation
being requested. When creating the varied helpers, the action also
passes the name of the variation. A test for this new functionality has
been added.
  • Loading branch information
vepadulano committed Jan 13, 2025
1 parent 878290b commit 8c61c06
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 59 deletions.
42 changes: 24 additions & 18 deletions tree/dataframe/inc/ROOT/RDF/ActionHelpers.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public:

std::string GetActionName() { return "Count"; }

CountHelper MakeNew(void *newResult)
CountHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<ULong64_t> *>(newResult);
return CountHelper(result, fCounts.size());
Expand Down Expand Up @@ -226,7 +226,13 @@ public:

std::string GetActionName() { return "Report"; }

// TODO implement MakeNew. Requires some smartness in passing the appropriate previous node.
/// Build a new ReportHelper that writes into a different result object, for use with VariationsFor().
/// \param newResult Type-erased pointer; it is cast to a `std::shared_ptr<RCutFlowReport> *` and the
///                  pointed-to shared pointer is handed to the new helper as its result.
/// \param variation Name of the variation being requested; it is forwarded to `fNode->GetVariedFilter()`
///                  to look up the node the new helper is attached to.
/// \return A ReportHelper constructed from the new result, the varied node and this helper's
///         `fReturnEmptyReport` value.
ReportHelper MakeNew(void *newResult, std::string_view variation = "nominal")
{
   auto &variedReport = *static_cast<std::shared_ptr<RCutFlowReport> *>(newResult);
   // NOTE(review): only the raw pointer is passed on; presumably the computation graph keeps the
   // varied filter alive for as long as the returned helper is in use -- confirm.
   auto variedNode = std::static_pointer_cast<RNode_t>(fNode->GetVariedFilter(std::string(variation)));
   return ReportHelper{variedReport, variedNode.get(), fReturnEmptyReport};
}
};

/// This helper fills TH1Ds for which no axes were specified by buffering the fill values to pick good axes limits.
Expand Down Expand Up @@ -330,7 +336,7 @@ public:
return std::string(fResultHist->IsA()->GetName()) + "\\n" + std::string(fResultHist->GetName());
}

BufferedFillHelper MakeNew(void *newResult)
BufferedFillHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<Hist_t> *>(newResult);
result->Reset();
Expand Down Expand Up @@ -548,7 +554,7 @@ public:
}

template <typename H = HIST>
FillHelper MakeNew(void *newResult)
FillHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<H> *>(newResult);
ResetIfPossible(result.get());
Expand Down Expand Up @@ -636,7 +642,7 @@ public:

Result_t &PartialUpdate(unsigned int slot) { return *fGraphs[slot]; }

FillTGraphHelper MakeNew(void *newResult)
FillTGraphHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<TGraph> *>(newResult);
result->Set(0);
Expand Down Expand Up @@ -742,7 +748,7 @@ public:

Result_t &PartialUpdate(unsigned int slot) { return *fGraphAsymmErrors[slot]; }

FillTGraphAsymmErrorsHelper MakeNew(void *newResult)
FillTGraphAsymmErrorsHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<TGraphAsymmErrors> *>(newResult);
result->Set(0);
Expand Down Expand Up @@ -808,7 +814,7 @@ public:

std::string GetActionName() { return "Take"; }

TakeHelper MakeNew(void *newResult)
TakeHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<COLL> *>(newResult);
result->clear();
Expand Down Expand Up @@ -861,7 +867,7 @@ public:

std::string GetActionName() { return "Take"; }

TakeHelper MakeNew(void *newResult)
TakeHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<std::vector<T>> *>(newResult);
result->clear();
Expand Down Expand Up @@ -906,7 +912,7 @@ public:

std::string GetActionName() { return "Take"; }

TakeHelper MakeNew(void *newResult)
TakeHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<COLL> *>(newResult);
result->clear();
Expand Down Expand Up @@ -958,7 +964,7 @@ public:

std::string GetActionName() { return "Take"; }

TakeHelper MakeNew(void *newResult)
TakeHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<typename decltype(fColls)::value_type *>(newResult);
result->clear();
Expand Down Expand Up @@ -1033,7 +1039,7 @@ public:

std::string GetActionName() { return "Min"; }

MinHelper MakeNew(void *newResult)
MinHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<ResultType> *>(newResult);
return MinHelper(result, fMins.size());
Expand Down Expand Up @@ -1083,7 +1089,7 @@ public:

std::string GetActionName() { return "Max"; }

MaxHelper MakeNew(void *newResult)
MaxHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<ResultType> *>(newResult);
return MaxHelper(result, fMaxs.size());
Expand Down Expand Up @@ -1166,7 +1172,7 @@ public:

std::string GetActionName() { return "Sum"; }

SumHelper MakeNew(void *newResult)
SumHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<ResultType> *>(newResult);
*result = NeutralElement(*result, -1);
Expand Down Expand Up @@ -1217,7 +1223,7 @@ public:

std::string GetActionName() { return "Mean"; }

MeanHelper MakeNew(void *newResult)
MeanHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<double> *>(newResult);
return MeanHelper(result, fSums.size());
Expand Down Expand Up @@ -1265,7 +1271,7 @@ public:

std::string GetActionName() { return "StdDev"; }

StdDevHelper MakeNew(void *newResult)
StdDevHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<double> *>(newResult);
return StdDevHelper(result, fCounts.size());
Expand Down Expand Up @@ -1634,7 +1640,7 @@ public:
* also involves changing the name of the output file, otherwise the cloned
* Snapshot would overwrite the same file.
*/
SnapshotHelper MakeNew(void *newName)
SnapshotHelper MakeNew(void *newName, std::string_view /*variation*/ = "nominal")
{
const std::string finalName = *reinterpret_cast<const std::string *>(newName);
return SnapshotHelper{
Expand Down Expand Up @@ -1830,7 +1836,7 @@ public:
* also involves changing the name of the output file, otherwise the cloned
* Snapshot would overwrite the same file.
*/
SnapshotHelperMT MakeNew(void *newName)
SnapshotHelperMT MakeNew(void *newName, std::string_view /*variation*/ = "nominal")
{
const std::string finalName = *reinterpret_cast<const std::string *>(newName);
return SnapshotHelperMT{fNSlots, finalName, fDirName, fTreeName,
Expand Down Expand Up @@ -1899,7 +1905,7 @@ public:

std::string GetActionName() { return "Aggregate"; }

AggregateHelper MakeNew(void *newResult)
AggregateHelper MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<U> *>(newResult);
return AggregateHelper(fAggregate, fMerge, result, fAggregators.size());
Expand Down
7 changes: 4 additions & 3 deletions tree/dataframe/inc/ROOT/RDF/RAction.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,15 @@ public:

std::unique_ptr<RActionBase> MakeVariedAction(std::vector<void *> &&results) final
{
const auto nVariations = GetVariations().size();
auto &&variations = GetVariations();
auto &&nVariations = variations.size();
assert(results.size() == nVariations);

std::vector<Helper> helpers;
helpers.reserve(nVariations);

for (auto &&res : results)
helpers.emplace_back(fHelper.CallMakeNew(res));
for (decltype(nVariations) i{}; i < nVariations; i++)
helpers.emplace_back(fHelper.CallMakeNew(results[i], variations[i]));

return std::unique_ptr<RActionBase>(new RVariedAction<Helper, PrevNode, ColumnTypes_t>{
std::move(helpers), GetColumnNames(), fPrevNodePtr, GetColRegister()});
Expand Down
35 changes: 24 additions & 11 deletions tree/dataframe/inc/ROOT/RDF/RActionImpl.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
#include <stdexcept> // std::logic_error
#include <utility> // std::declval

namespace ROOT::Internal::RDF {
/// Compile-time check for a member function callable as `MakeNew(void *, std::string_view)` on a `T`.
/// The outcome is available as `HasMakeNew<T>::value`.
template <typename T>
class HasMakeNew {
   // Preferred overload (exact match on the `int` argument). It only takes part in overload
   // resolution when the MakeNew call expression is well-formed for `Candidate`.
   template <typename Candidate,
             typename = decltype(std::declval<Candidate>().MakeNew(static_cast<void *>(nullptr), std::string_view{}))>
   static std::true_type Probe(int);
   // Fallback, reached through the ellipsis conversion when the overload above is discarded.
   template <typename Candidate>
   static std::false_type Probe(...);

   using Detected_t = decltype(Probe<T>(0));

public:
   static constexpr bool value = Detected_t::value;
};
} // namespace ROOT::Internal::RDF

namespace ROOT {
namespace Detail {
namespace RDF {
Expand Down Expand Up @@ -48,18 +61,18 @@ public:
throw std::logic_error("This action does not support callbacks!");
}

template <typename T = Helper>
auto CallMakeNew(void *typeErasedResSharedPtr) -> decltype(std::declval<T>().MakeNew(typeErasedResSharedPtr))
{
return static_cast<Helper *>(this)->MakeNew(typeErasedResSharedPtr);
}

template <typename... Args>
[[noreturn]] Helper CallMakeNew(void *, Args...)
Helper CallMakeNew(void *typeErasedResSharedPtr, std::string_view variation = "nominal")
{
const auto &actionName = static_cast<Helper *>(this)->GetActionName();
throw std::logic_error("The MakeNew method is not implemented for this action helper (" + actionName +
"). Cannot Vary its result.");
if constexpr (ROOT::Internal::RDF::HasMakeNew<Helper>::value)
return static_cast<Helper *>(this)->MakeNew(typeErasedResSharedPtr, variation);
else {
// Avoid unused parameter warning with GCC
(void)typeErasedResSharedPtr;
(void)variation;
const auto &actionName = static_cast<Helper *>(this)->GetActionName();
throw std::logic_error("The MakeNew method is not implemented for this action helper (" + actionName +
"). Cannot Vary its result.");
}
}

// Helper functions for RMergeableValue
Expand Down
7 changes: 4 additions & 3 deletions tree/dataframe/inc/ROOT/RDF/RInterface.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -2967,9 +2967,10 @@ public:
/// * `ROOT::RDF::SampleCallback_t GetSampleCallback()`: if present, it must return a callable with the
/// appropriate signature (see ROOT::RDF::SampleCallback_t) that will be invoked at the beginning of the processing
/// of every sample, as in DefinePerSample().
/// * `Helper MakeNew(void *newResult)`: if implemented, it enables varying the action's result with VariationsFor(). It takes a
/// type-erased new result that can be safely cast to a `std::shared_ptr<Result_t> *` (a pointer to shared pointer) and should
/// be used as the action's output result.
/// * `Helper MakeNew(void *newResult, std::string_view variation = "nominal")`: if implemented, it enables varying
/// the action's result with VariationsFor(). It takes a type-erased new result that can be safely cast to a
/// `std::shared_ptr<Result_t> *` (a pointer to shared pointer) and should be used as the action's output result.
/// The function optionally takes the name of the current variation, which can be used to customize its behaviour.
///
/// In case Book is called without specifying column types as template arguments, corresponding typed code will be just-in-time compiled
/// by RDataFrame. In that case the Helper class needs to be known to the ROOT interpreter.
Expand Down
2 changes: 1 addition & 1 deletion tree/dataframe/src/RDataFrame.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,7 @@ shorthand that automatically generates tags 0 to N-1 (in this case 0 and 1).
interfaces might still evolve and improve based on user feedback. We expect that some aspects of the related
programming model will be streamlined in future versions.
\note Currently, the results of a Snapshot(), Report() or Display() call cannot be varied (i.e. it is not possible to
\note Currently, the results of a Snapshot() or Display() call cannot be varied (i.e. it is not possible to
call \ref ROOT::RDF::Experimental::VariationsFor "VariationsFor()" on them). These limitations will be lifted in future releases.
See the Vary() method for more information and [this tutorial](https://root.cern/doc/master/df106__HiggsToFourLeptons_8C.html)
Expand Down
52 changes: 29 additions & 23 deletions tree/dataframe/test/dataframe_vary.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ struct MyCounter : public ROOT::Detail::RDF::RActionImpl<MyCounter> {

std::string GetActionName() const { return "MyCounter"; }

MyCounter MakeNew(void *newResult)
MyCounter MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto &result = *static_cast<std::shared_ptr<int> *>(newResult);
return MyCounter(result, fPerThreadResults.size());
Expand Down Expand Up @@ -1453,36 +1453,42 @@ TEST_P(RDFVary, VaryReduce)
EXPECT_EQ(hs["x:1"], 55);
}

// Varying Reports is not implemented yet, tracked by https://github.com/root-project/root/issues/10551
TEST_P(RDFVary, VaryReport)
{
auto h = ROOT::RDataFrame(10)
.Define("x", [](ULong64_t e) { return int(e); }, {"rdfentry_"})
.Filter([](int x) { return x > 5; }, {"x"}, "before")
.Vary(
"x",
[](int x) {
return ROOT::RVecI{x - 1, x + 1};
},
{"x"}, 2)
.Filter([](int x) { return x > 7; }, {"x"}, "after")
.Report();
auto &report = *h;
auto rep = ROOT::RDataFrame(10)
.Define("x", [](ULong64_t e) { return int(e); }, {"rdfentry_"})
.Filter([](int x) { return x > 5; }, {"x"}, "before")
.Vary("x", [](int x) { return ROOT::RVecI{x - 1, x + 1}; }, {"x"}, {"down", "up"})
.Filter([](int x) { return x > 7; }, {"x"}, "after")
.Report();

auto reps = VariationsFor(rep);
auto &&report = reps["nominal"];

EXPECT_EQ(report["before"].GetAll(), 10);
EXPECT_FLOAT_EQ(report["before"].GetEff(), 40.);
EXPECT_EQ(report["before"].GetPass(), 4);
EXPECT_EQ(report["after"].GetAll(), 4);
EXPECT_FLOAT_EQ(report["after"].GetEff(), 50.);
EXPECT_EQ(report["after"].GetPass(), 2);
EXPECT_THROW(
try { VariationsFor(h); } catch (const std::logic_error &err) {
const auto msg = "The MakeNew method is not implemented for this action helper (Report). "
"Cannot Vary its result.";
EXPECT_STREQ(err.what(), msg);
throw;
},
std::logic_error);

auto &&report_up = reps["x:up"];

EXPECT_EQ(report_up["before"].GetAll(), 10);
EXPECT_FLOAT_EQ(report_up["before"].GetEff(), 40.);
EXPECT_EQ(report_up["before"].GetPass(), 4);
EXPECT_EQ(report_up["after"].GetAll(), 4);
EXPECT_FLOAT_EQ(report_up["after"].GetEff(), 75.);
EXPECT_EQ(report_up["after"].GetPass(), 3);

auto &&report_down = reps["x:down"];

EXPECT_EQ(report_down["before"].GetAll(), 10);
EXPECT_FLOAT_EQ(report_down["before"].GetEff(), 40.);
EXPECT_EQ(report_down["before"].GetPass(), 4);
EXPECT_EQ(report_down["after"].GetAll(), 4);
EXPECT_FLOAT_EQ(report_down["after"].GetEff(), 25.);
EXPECT_EQ(report_down["after"].GetPass(), 1);
}

TEST_P(RDFVary, VaryStdDev)
Expand Down Expand Up @@ -1630,7 +1636,7 @@ struct HelperWithCallback : ROOT::Detail::RDF::RActionImpl<HelperWithCallback> {
return callback;
}

HelperWithCallback MakeNew(void *newResult)
HelperWithCallback MakeNew(void *newResult, std::string_view /*variation*/ = "nominal")
{
auto newHelper = HelperWithCallback();
newHelper.fResult = *static_cast<std::shared_ptr<Result_t> *>(newResult);
Expand Down

0 comments on commit 8c61c06

Please sign in to comment.