diff --git a/tree/treeplayer/inc/ROOT/TDFInterface.hxx b/tree/treeplayer/inc/ROOT/TDFInterface.hxx index ced134dae62f78365066d862605bb397b0c7d790..d0c7864c97b961332cc58c542af2d8d341e17ece 100644 --- a/tree/treeplayer/inc/ROOT/TDFInterface.hxx +++ b/tree/treeplayer/inc/ROOT/TDFInterface.hxx @@ -139,6 +139,13 @@ std::shared_ptr<T> *MakeSharedOnHeap(const std::shared_ptr<T> &shPtr) bool AtLeastOneEmptyString(const std::vector<std::string_view> strings); +/* The following functions upcast shared ptrs to TFilter, TCustomColumn, TRange to their parent class (***Base). + * Shared ptrs to TLoopManager are just copied, as well as shared ptrs to ***Base classes. */ +std::shared_ptr<TFilterBase> UpcastNode(const std::shared_ptr<TFilterBase> ptr); +std::shared_ptr<TCustomColumnBase> UpcastNode(const std::shared_ptr<TCustomColumnBase> ptr); +std::shared_ptr<TRangeBase> UpcastNode(const std::shared_ptr<TRangeBase> ptr); +std::shared_ptr<TLoopManager> UpcastNode(const std::shared_ptr<TLoopManager> ptr); + } // namespace TDF } // namespace Internal @@ -181,6 +188,21 @@ class TInterface { const std::shared_ptr<Proxied> fProxiedPtr; ///< Smart pointer to the graph node encapsulated by this TInterface. const std::weak_ptr<TLoopManager> fImplWeakPtr; ///< Weak pointer to the TLoopManager at the root of the graph. public: + /// \cond HIDDEN_SYMBOLS + // Template conversion operator, meant to use to convert TInterfaces of certain node types to TInterfaces of base + // classes of those node types, e.g. TInterface<TFilter<F,P>> -> TInterface<TFilterBase>. + // It is used implicitly when a call to Filter or Define is jitted: the jitted call must convert the + // TInterface returned by the jitted transformations to a TInterface<***Base> before returning. + // Must be public because it is cling that uses it. + template <typename NewProxied> + operator TInterface<NewProxied>() + { + static_assert(std::is_base_of<NewProxied, Proxied>::value, + "TInterface<T> can only be converted to TInterface<BaseOfT>"); + return TInterface<NewProxied>(fProxiedPtr, fImplWeakPtr); + } + /// \endcond + //////////////////////////////////////////////////////////////////////////// /// \brief Append a filter to the call graph. /// \param[in] f Function, lambda expression, functor class or any other callable object. It must return a `bool` @@ -201,7 +223,7 @@ public: /// it is executed once per entry. If its result is requested more than /// once, the cached result is served. template <typename F, typename std::enable_if<!std::is_convertible<F, std::string>::value, int>::type = 0> - TInterface<TFilterBase> Filter(F f, const ColumnNames_t &columns = {}, std::string_view name = "") + TInterface<TDFDetail::TFilter<F, Proxied>> Filter(F f, const ColumnNames_t &columns = {}, std::string_view name = "") { TDFInternal::CheckFilter(f); auto loopManager = GetDataFrameChecked(); @@ -210,7 +232,7 @@ public: using F_t = TDFDetail::TFilter<F, Proxied>; auto FilterPtr = std::make_shared<F_t>(std::move(f), validColumnNames, *fProxiedPtr, name); loopManager->Book(FilterPtr); - return TInterface<TFilterBase>(FilterPtr, fImplWeakPtr); + return TInterface<F_t>(FilterPtr, fImplWeakPtr); } //////////////////////////////////////////////////////////////////////////// @@ -221,7 +243,7 @@ public: /// /// Refer to the first overload of this method for the full documentation. template <typename F, typename std::enable_if<!std::is_convertible<F, std::string>::value, int>::type = 0> - TInterface<TFilterBase> Filter(F f, std::string_view name) + TInterface<TDFDetail::TFilter<F, Proxied>> Filter(F f, std::string_view name) { // The sfinae is there in order to pick up the overloaded method which accepts two strings // rather than this template method. @@ -236,7 +258,7 @@ public: /// /// Refer to the first overload of this method for the full documentation. template <typename F> - TInterface<TFilterBase> Filter(F f, const std::initializer_list<std::string> &columns) + TInterface<TDFDetail::TFilter<F, Proxied>> Filter(F f, const std::initializer_list<std::string> &columns) { return Filter(f, ColumnNames_t{columns}); } @@ -278,7 +300,8 @@ public: /// /// An exception is thrown if the name of the new column is already in use. template <typename F, typename std::enable_if<!std::is_convertible<F, std::string>::value, int>::type = 0> - TInterface<TCustomColumnBase> Define(std::string_view name, F expression, const ColumnNames_t &columns = {}) + TInterface<TDFDetail::TCustomColumn<F, Proxied>> + Define(std::string_view name, F expression, const ColumnNames_t &columns = {}) { auto loopManager = GetDataFrameChecked(); TDFInternal::CheckTmpBranch(name, loopManager->GetTree()); @@ -288,7 +311,7 @@ public: const std::string nameInt(name); auto BranchPtr = std::make_shared<B_t>(nameInt, std::move(expression), validColumnNames, *fProxiedPtr); loopManager->Book(BranchPtr); - return TInterface<TCustomColumnBase>(BranchPtr, fImplWeakPtr); + return TInterface<B_t>(BranchPtr, fImplWeakPtr); } //////////////////////////////////////////////////////////////////////////// @@ -337,11 +360,13 @@ public: auto df = GetDataFrameChecked(); auto tree = df->GetTree(); std::stringstream snapCall; + auto upcastNode = TDFInternal::UpcastNode(fProxiedPtr); + TInterface<TTraits::TakeFirstParameter_t<decltype(upcastNode)>> upcastInterface(fProxiedPtr, fImplWeakPtr); // build a string equivalent to - // "reinterpret_cast</nodetype/*>(this)->Snapshot<Ts...>(treename,filename,*reinterpret_cast<ColumnNames_t*>(&columnList))" + // "(TInterface<nodetype*>*)(this)->Snapshot<Ts...>(treename,filename,*(ColumnNames_t*)(&columnList))" snapCall << "if (gROOTMutex) gROOTMutex->UnLock();"; // black magic: avoids a deadlock in the interpreter - snapCall << "reinterpret_cast<ROOT::Experimental::TDF::TInterface<" << GetNodeTypeName() << ">*>(" << this - << ")->Snapshot<"; + snapCall << "reinterpret_cast<ROOT::Experimental::TDF::TInterface<" << upcastInterface.GetNodeTypeName() << ">*>(" + << &upcastInterface << ")->Snapshot<"; bool first = true; for (auto &b : columnList) { if (!first) @@ -421,7 +446,7 @@ public: /// \param[in] stride Process one entry every `stride` entries. Must be strictly greater than 0. /// /// Ranges are only available if EnableImplicitMT has _not_ been called. Multi-thread ranges are not supported. - TInterface<TRangeBase> Range(unsigned int start, unsigned int stop, unsigned int stride = 1) + TInterface<TDFDetail::TRange<Proxied>> Range(unsigned int start, unsigned int stop, unsigned int stride = 1) { // check invariants if (stride == 0 || (stop != 0 && stop < start)) @@ -433,7 +458,7 @@ public: using Range_t = TDFDetail::TRange<Proxied>; auto RangePtr = std::make_shared<Range_t>(start, stop, stride, *fProxiedPtr); df->Book(RangePtr); - TInterface<TRangeBase> tdf_r(RangePtr, fImplWeakPtr); + TInterface<TDFDetail::TRange<Proxied>> tdf_r(RangePtr, fImplWeakPtr); return tdf_r; } @@ -442,7 +467,7 @@ public: /// \param[in] stop Total number of entries that will be processed before stopping. 0 means "never stop". /// /// See the other Range overload for a detailed description. - TInterface<TRangeBase> Range(unsigned int stop) { return Range(0, stop, 1); } + TInterface<TDFDetail::TRange<Proxied>> Range(unsigned int stop) { return Range(0, stop, 1); } //////////////////////////////////////////////////////////////////////////// /// \brief Execute a user-defined function on each entry (*instant action*) @@ -1056,11 +1081,16 @@ private: const std::string transformInt(transformation); const std::string nameInt(nodeName); const std::string expressionInt(expression); - const auto thisTypeName = "ROOT::Experimental::TDF::TInterface<" + GetNodeTypeName() + ">"; - return TDFInternal::JitTransformation(this, transformInt, thisTypeName, nameInt, expressionInt, branches, - tmpBranches, tmpBookedBranches, tree); + auto upcastNode = TDFInternal::UpcastNode(fProxiedPtr); + TInterface<TypeTraits::TakeFirstParameter_t<decltype(upcastNode)>> upcastInterface(upcastNode, fImplWeakPtr); + const auto thisTypeName = "ROOT::Experimental::TDF::TInterface<" + upcastInterface.GetNodeTypeName() + ">"; + return TDFInternal::JitTransformation(&upcastInterface, transformInt, thisTypeName, nameInt, expressionInt, + branches, tmpBranches, tmpBookedBranches, tree); } + /// Return string containing fully qualified type name of the node pointed by fProxied. + /// The method is only defined for TInterface<{TFilterBase,TCustomColumnBase,TRangeBase,TLoopManager}> as it should + /// only be called on "upcast" TInterfaces. inline std::string GetNodeTypeName(); // Type was specified by the user, no need to infer it @@ -1092,7 +1122,9 @@ private: const auto &tmpBranches = loopManager->GetBookedBranches(); auto tree = loopManager->GetTree(); auto rOnHeap = TDFInternal::MakeSharedOnHeap(r); - auto toJit = TDFInternal::JitBuildAndBook(validColumnNames, GetNodeTypeName(), fProxiedPtr.get(), + auto upcastNode = TDFInternal::UpcastNode(fProxiedPtr); + TInterface<TypeTraits::TakeFirstParameter_t<decltype(upcastNode)>> upcastInterface(upcastNode, fImplWeakPtr); + auto toJit = TDFInternal::JitBuildAndBook(validColumnNames, upcastInterface.GetNodeTypeName(), upcastNode.get(), typeid(std::shared_ptr<ActionResultType>), typeid(ActionType), rOnHeap, tree, nSlots, tmpBranches); loopManager->Jit(toJit); diff --git a/tree/treeplayer/src/TDFInterface.cxx b/tree/treeplayer/src/TDFInterface.cxx index 1e087af4aaba666dcfdee4579ddb02ee2a84235c..d809b25f39693b949cf95e69431b1f591bcd13dd 100644 --- a/tree/treeplayer/src/TDFInterface.cxx +++ b/tree/treeplayer/src/TDFInterface.cxx @@ -131,9 +131,15 @@ Long_t JitTransformation(void *thisPtr, const std::string &methodName, const std ss << "){ return " << expression << ";}"; auto filterLambda = ss.str(); + // The TInterface type to convert the result to. For example, Filter returns a TInterface<TFilter<F,P>> but when + // returning it from a jitted call we need to convert it to TInterface<TFilterBase> as we are missing information + // on types F and P at compile time. + const auto targetTypeName = std::string("ROOT::Experimental::TDF::TInterface<ROOT::Detail::TDF::") + + (methodName == "Define" ? "TCustomColumnBase" : "TFilterBase") + ">"; + // Here we have two cases: filter and column ss.str(""); - ss << "((" << interfaceTypeName << "*)" << thisPtr << ")->" << methodName << "("; + ss << targetTypeName << "(((" << interfaceTypeName << "*)" << thisPtr << ")->" << methodName << "("; if (methodName == "Define") { ss << "\"" << name << "\", "; } @@ -149,7 +155,7 @@ Long_t JitTransformation(void *thisPtr, const std::string &methodName, const std ss << ", \"" << name << "\""; } - ss << ");"; + ss << "));"; TInterpreter::EErrorCode interpErrCode; auto retVal = gInterpreter->ProcessLine(ss.str().c_str(), &interpErrCode); @@ -235,6 +241,27 @@ bool AtLeastOneEmptyString(const std::vector<std::string_view> strings) } return false; } + +std::shared_ptr<TFilterBase> UpcastNode(const std::shared_ptr<TFilterBase> ptr) +{ + return ptr; +} + +std::shared_ptr<TCustomColumnBase> UpcastNode(const std::shared_ptr<TCustomColumnBase> ptr) +{ + return ptr; +} + +std::shared_ptr<TRangeBase> UpcastNode(const std::shared_ptr<TRangeBase> ptr) +{ + return ptr; +} + +std::shared_ptr<TLoopManager> UpcastNode(const std::shared_ptr<TLoopManager> ptr) +{ + return ptr; +} + } // end ns TDF } // end ns Internal } // end ns ROOT