1414#include " ROOT/RDF/ColumnReaderUtils.hxx"
1515#include " ROOT/RDF/GraphNode.hxx"
1616#include " ROOT/RDF/RActionBase.hxx"
17+ #include " ROOT/RDF/RFilterBase.hxx"
18+ #include " ROOT/RDF/RJittedFilter.hxx"
1719#include " ROOT/RDF/RLoopManager.hxx"
1820
1921#include < cstddef> // std::size_t
@@ -37,8 +39,12 @@ class R__CLING_PTRCHECK(off) RActionSnapshot final : public RActionBase {
3739 // Template needed to avoid dependency on ActionHelpers.hxx
3840 Helper fHelper ;
3941
40- // / Pointer to the previous node in this branch of the computation graph
41- std::vector<std::shared_ptr<PrevNode>> fPrevNodes ;
42+ // If the PrevNode is a RJittedFilter, our collection of previous nodes will have to use the RFilterBase type:
43+ // we'll have a RJittedFilter for the nominal case, but the others will be concrete filters.
44+ using PrevNodeCommon_t = std::conditional_t <std::is_same_v<PrevNode, ROOT::Detail::RDF::RJittedFilter>,
45+ ROOT::Detail::RDF::RFilterBase, PrevNode>;
46+ // / Previous nodes in the computation graph. First element is nominal, others are varied.
47+ std::vector<std::shared_ptr<PrevNodeCommon_t>> fPrevNodes ;
4248
4349 // / Column readers per slot and per input column
4450 std::vector<std::vector<RColumnReaderBase *>> fValues ;
@@ -51,13 +57,46 @@ class R__CLING_PTRCHECK(off) RActionSnapshot final : public RActionBase {
5157
5258 ROOT::RDF::SampleCallback_t GetSampleCallback () final { return fHelper .GetSampleCallback (); }
5359
60+ void AppendVariedPrevNodes ()
61+ {
62+ // This method only makes sense if we're appending the varied filters to the list after the nominal
63+ assert (fPrevNodes .size () == 1 );
64+ const auto ¤tVariations = GetVariations ();
65+
66+ // If this node hangs from the RLoopManager itself, just use that as the upstream node for each variation
67+ auto nominalPrevNode = fPrevNodes .begin ();
68+ if (static_cast <ROOT::Detail::RDF::RNodeBase *>(nominalPrevNode->get ()) == fLoopManager ) {
69+ fPrevNodes .resize (1 + currentVariations.size (), *nominalPrevNode);
70+ return ;
71+ }
72+
73+ // Otherwise, append one varied filter per variation
74+ const auto &prevVariations = (*nominalPrevNode)->GetVariations ();
75+
76+ fPrevNodes .reserve (1 + prevVariations.size ());
77+ // Get valid iterator after resizing
78+ nominalPrevNode = fPrevNodes .begin ();
79+
80+ // Need to populate parts of the computation graph for which we have empty shells, e.g. RJittedFilters
81+ if (!currentVariations.empty ())
82+ fLoopManager ->Jit ();
83+ for (const auto &variation : currentVariations) {
84+ if (IsStrInVec (variation, prevVariations)) {
85+ fPrevNodes .emplace_back (
86+ std::static_pointer_cast<PrevNodeCommon_t>((*nominalPrevNode)->GetVariedFilter (variation)));
87+ } else {
88+ fPrevNodes .push_back (*nominalPrevNode);
89+ }
90+ }
91+ }
92+
5493public:
5594 RActionSnapshot (Helper &&h, const std::vector<std::string> &columns,
5695 const std::vector<const std::type_info *> &colTypeIDs, std::shared_ptr<PrevNode> pd,
5796 const RColumnRegister &colRegister)
5897 : RActionBase (pd->GetLoopManagerUnchecked (), columns, colRegister, pd->GetVariations ()),
5998 fHelper (std::move (h)),
60- fPrevNodes {std::move (pd)},
99+ fPrevNodes {std::static_pointer_cast<PrevNodeCommon_t> (pd)},
61100 fValues (GetNSlots ()),
62101 fColTypeIDs (colTypeIDs)
63102 {
@@ -69,26 +108,7 @@ public:
69108 fIsDefine .push_back (colRegister.IsDefineOrAlias (columns[i]));
70109
71110 if constexpr (std::is_same_v<Helper, SnapshotHelperWithVariations>) {
72- if (const auto &variations = GetVariations (); !variations.empty ()) {
73- // Get pointers to previous nodes of all systematics
74- fPrevNodes .reserve (1 + variations.size ());
75- auto nominalFilter = fPrevNodes .front ();
76- if (static_cast <RNodeBase *>(nominalFilter.get ()) == fLoopManager ) {
77- // just fill this with the RLoopManager N times
78- fPrevNodes .resize (1 + variations.size (), nominalFilter);
79- } else {
80- // create varied versions of the previous filter node
81- const auto &prevVariations = nominalFilter->GetVariations ();
82- for (const auto &variation : variations) {
83- if (IsStrInVec (variation, prevVariations)) {
84- fPrevNodes .emplace_back (
85- std::static_pointer_cast<PrevNode>(nominalFilter->GetVariedFilter (variation)));
86- } else {
87- fPrevNodes .emplace_back (nominalFilter);
88- }
89- }
90- }
91- }
111+ AppendVariedPrevNodes ();
92112 }
93113 }
94114
@@ -251,7 +271,8 @@ public:
251271 std::unique_ptr<RActionBase> CloneAction (void *newResult) final
252272 {
253273 return std::make_unique<RActionSnapshot>(fHelper .CallMakeNew (newResult), GetColumnNames (), fColTypeIDs ,
254- fPrevNodes .front (), GetColRegister ());
274+ std::static_pointer_cast<PrevNode>(fPrevNodes .front ()),
275+ GetColRegister ());
255276 }
256277};
257278
0 commit comments