Skip to content

Commit 126b0a3

Browse files
committed
[df] Consider jitted nodes in Snapshot with variations
This commit introduces two enhancements to the Snapshot with variations implementation: - Snapshot with variations must call RLoopManager::Jit before trying to request varied filters from upstream nodes. This is necessary because the upstream node could be a jitted node itself. - Snapshot with variations must deal with the possibility of having a different nominal upstream node type than the other varied upstream node types. This aligns the implementation witht the one of RVariedAction for what concerns dealing with jitted nodes. Fixes #20320
1 parent f80ee47 commit 126b0a3

File tree

2 files changed

+85
-24
lines changed

2 files changed

+85
-24
lines changed

tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include "ROOT/RDF/ColumnReaderUtils.hxx"
1515
#include "ROOT/RDF/GraphNode.hxx"
1616
#include "ROOT/RDF/RActionBase.hxx"
17+
#include "ROOT/RDF/RFilterBase.hxx"
18+
#include "ROOT/RDF/RJittedFilter.hxx"
1719
#include "ROOT/RDF/RLoopManager.hxx"
1820

1921
#include <cstddef> // std::size_t
@@ -37,8 +39,12 @@ class R__CLING_PTRCHECK(off) RActionSnapshot final : public RActionBase {
3739
// Template needed to avoid dependency on ActionHelpers.hxx
3840
Helper fHelper;
3941

40-
/// Pointer to the previous node in this branch of the computation graph
41-
std::vector<std::shared_ptr<PrevNode>> fPrevNodes;
42+
// If the PrevNode is a RJittedFilter, our collection of previous nodes will have to use the RFilterBase type:
43+
// we'll have a RJittedFilter for the nominal case, but the others will be concrete filters.
44+
using PrevNodeCommon_t = std::conditional_t<std::is_same_v<PrevNode, ROOT::Detail::RDF::RJittedFilter>,
45+
ROOT::Detail::RDF::RFilterBase, PrevNode>;
46+
/// Previous nodes in the computation graph. First element is nominal, others are varied.
47+
std::vector<std::shared_ptr<PrevNodeCommon_t>> fPrevNodes;
4248

4349
/// Column readers per slot and per input column
4450
std::vector<std::vector<RColumnReaderBase *>> fValues;
@@ -51,13 +57,46 @@ class R__CLING_PTRCHECK(off) RActionSnapshot final : public RActionBase {
5157

5258
ROOT::RDF::SampleCallback_t GetSampleCallback() final { return fHelper.GetSampleCallback(); }
5359

60+
void AppendVariedPrevNodes()
61+
{
62+
// This method only makes sense if we're appending the varied filters to the list after the nominal
63+
assert(fPrevNodes.size() == 1);
64+
const auto &currentVariations = GetVariations();
65+
66+
// If this node hangs from the RLoopManager itself, just use that as the upstream node for each variation
67+
auto nominalPrevNode = fPrevNodes.begin();
68+
if (static_cast<ROOT::Detail::RDF::RNodeBase *>(nominalPrevNode->get()) == fLoopManager) {
69+
fPrevNodes.resize(1 + currentVariations.size(), *nominalPrevNode);
70+
return;
71+
}
72+
73+
// Otherwise, append one varied filter per variation
74+
const auto &prevVariations = (*nominalPrevNode)->GetVariations();
75+
76+
// Need to populate parts of the computation graph for which we have empty shells, e.g. RJittedFilters
77+
if (!prevVariations.empty())
78+
fLoopManager->Jit();
79+
80+
fPrevNodes.reserve(1 + prevVariations.size());
81+
// Get valid iterator after resizing
82+
nominalPrevNode = fPrevNodes.begin();
83+
for (const auto &variation : currentVariations) {
84+
if (IsStrInVec(variation, prevVariations)) {
85+
fPrevNodes.emplace_back(
86+
std::static_pointer_cast<PrevNodeCommon_t>((*nominalPrevNode)->GetVariedFilter(variation)));
87+
} else {
88+
fPrevNodes.push_back(*nominalPrevNode);
89+
}
90+
}
91+
}
92+
5493
public:
5594
RActionSnapshot(Helper &&h, const std::vector<std::string> &columns,
5695
const std::vector<const std::type_info *> &colTypeIDs, std::shared_ptr<PrevNode> pd,
5796
const RColumnRegister &colRegister)
5897
: RActionBase(pd->GetLoopManagerUnchecked(), columns, colRegister, pd->GetVariations()),
5998
fHelper(std::move(h)),
60-
fPrevNodes{std::move(pd)},
99+
fPrevNodes{std::static_pointer_cast<PrevNodeCommon_t>(pd)},
61100
fValues(GetNSlots()),
62101
fColTypeIDs(colTypeIDs)
63102
{
@@ -69,26 +108,7 @@ public:
69108
fIsDefine.push_back(colRegister.IsDefineOrAlias(columns[i]));
70109

71110
if constexpr (std::is_same_v<Helper, SnapshotHelperWithVariations>) {
72-
if (const auto &variations = GetVariations(); !variations.empty()) {
73-
// Get pointers to previous nodes of all systematics
74-
fPrevNodes.reserve(1 + variations.size());
75-
auto nominalFilter = fPrevNodes.front();
76-
if (static_cast<RNodeBase *>(nominalFilter.get()) == fLoopManager) {
77-
// just fill this with the RLoopManager N times
78-
fPrevNodes.resize(1 + variations.size(), nominalFilter);
79-
} else {
80-
// create varied versions of the previous filter node
81-
const auto &prevVariations = nominalFilter->GetVariations();
82-
for (const auto &variation : variations) {
83-
if (IsStrInVec(variation, prevVariations)) {
84-
fPrevNodes.emplace_back(
85-
std::static_pointer_cast<PrevNode>(nominalFilter->GetVariedFilter(variation)));
86-
} else {
87-
fPrevNodes.emplace_back(nominalFilter);
88-
}
89-
}
90-
}
91-
}
111+
AppendVariedPrevNodes();
92112
}
93113
}
94114

@@ -251,7 +271,8 @@ public:
251271
std::unique_ptr<RActionBase> CloneAction(void *newResult) final
252272
{
253273
return std::make_unique<RActionSnapshot>(fHelper.CallMakeNew(newResult), GetColumnNames(), fColTypeIDs,
254-
fPrevNodes.front(), GetColRegister());
274+
std::static_pointer_cast<PrevNode>(fPrevNodes.front()),
275+
GetColRegister());
255276
}
256277
};
257278

tree/dataframe/test/dataframe_snapshotWithVariations.cxx

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,3 +494,43 @@ TEST(RDFVarySnapshot, SnapshotVirtualClass)
494494
}
495495
}
496496
}
497+
498+
// https://github.com/root-project/root/issues/20320
499+
TEST(RDFVarySnapshot, GH20320)
500+
{
501+
struct FileRAII {
502+
const char *fPath{};
503+
504+
FileRAII(const char *path) : fPath(path) {}
505+
506+
~FileRAII() { std::remove(fPath); }
507+
} outputGuard{"dataframe_snapshot_with_variations_regression_gh20330.root"};
508+
509+
ROOT::RDataFrame df{1};
510+
511+
auto df_def = df.Define("val", []() { return 2; });
512+
513+
auto df_var =
514+
df_def.Vary("val", [](int val) { return ROOT::RVecI{val - 1, val + 1}; }, {"val"}, {"down", "up"}, "var");
515+
516+
// Jitted filters used to break the Snapshot because:
517+
// - It did not JIT the RJittedFilter before requesting for the varied filters
518+
// - It did not take into account that the previous nodes of the Snapshot could be of different types
519+
auto df_fil = df_var.Filter("val > 0");
520+
521+
ROOT::RDF::RSnapshotOptions opts;
522+
opts.fIncludeVariations = true;
523+
auto snap = df_fil.Snapshot("tree", outputGuard.fPath, {"val"}, opts);
524+
525+
auto take_val = snap->Take<int>("val");
526+
auto take_var_up = snap->Take<int>("val__var_up");
527+
auto take_var_down = snap->Take<int>("val__var_down");
528+
529+
EXPECT_EQ(take_val->size(), 1);
530+
EXPECT_EQ(take_var_up->size(), 1);
531+
EXPECT_EQ(take_var_down->size(), 1);
532+
533+
EXPECT_EQ(take_var_down->front(), 1);
534+
EXPECT_EQ(take_val->front(), 2);
535+
EXPECT_EQ(take_var_up->front(), 3);
536+
}

0 commit comments

Comments
 (0)