diff --git a/tree/dataframe/inc/ROOT/RDFHelpers.hxx b/tree/dataframe/inc/ROOT/RDFHelpers.hxx
index 51aa7a65cad7ec349ae96b73c73b0f535f307788..289c23796047e0ad359352295618e6e5e4ccff9a 100644
--- a/tree/dataframe/inc/ROOT/RDFHelpers.hxx
+++ b/tree/dataframe/inc/ROOT/RDFHelpers.hxx
@@ -25,7 +25,6 @@
 #include <memory>
 #include <fstream>
 #include <iostream>
-#include "TString.h"
 
 namespace ROOT {
 
diff --git a/tree/dataframe/inc/ROOT/RDFInterface.hxx b/tree/dataframe/inc/ROOT/RDFInterface.hxx
index 61439e2d5aff5b9d5bc3e53b8a4b8af214d27296..93492e2c9a61d18d621ddf52f0a58048cfedb029 100644
--- a/tree/dataframe/inc/ROOT/RDFInterface.hxx
+++ b/tree/dataframe/inc/ROOT/RDFInterface.hxx
@@ -548,8 +548,8 @@ public:
          cacheCall << RDFInternal::ColumnName2ColumnTypeName(c, nsID, tree, fDataSource, isCustom) << ", ";
       };
       if (!columnList.empty())
-         snapCall.seekp(-2, snapCall.cur); // remove the last ",
-      snapCall << ">(*reinterpret_cast<std::vector<std::string>*>(" // vector<string> should be ColumnNames_t
+         cacheCall.seekp(-2, cacheCall.cur); // remove the last ",
+      cacheCall << ">(*reinterpret_cast<std::vector<std::string>*>(" // vector<string> should be ColumnNames_t
                << RDFInternal::PrettyPrintAddr(&columnList) << "));";
       // jit cacheCall, return result
       TInterpreter::EErrorCode errorCode;
@@ -1425,6 +1425,26 @@ public:
    /// This is not an action nor a transformation, just a query to the RDataFrame object.
    std::vector<std::string> GetFilterNames() { return RDFInternal::GetFilterNames(fProxiedPtr); }
 
+   /// \brief Returns the names of the defined columns
+   ///
+   /// This is not an action nor a transformation, just a simple utility to
+   /// get the names of the columns that have been defined up to this node.
+   /// If no custom column has been defined, e.g. on a root node, it returns an
+   /// empty collection.
+   ColumnNames_t GetDefinedColumnNames()
+   {
+      ColumnNames_t definedColumns;
+
+      const auto &columns = fCustomColumns.GetColumns(); // bind by reference: no copy of the container
+
+      for (const auto &column : columns) { // each entry is only inspected, never copied
+         if (!RDFInternal::IsInternalColumn(column.first) && !column.second->IsDataSourceColumn())
+            definedColumns.emplace_back(column.first);
+      }
+
+      return definedColumns;
+   }
+
    // clang-format off
    ////////////////////////////////////////////////////////////////////////////
    /// \brief Execute a user-defined accumulation operation on the processed column values in each processing slot
diff --git a/tree/dataframe/test/dataframe_interface.cxx b/tree/dataframe/test/dataframe_interface.cxx
index d492873b36a3d632a2aa67d8e9b8cc315225ea33..f35559bd94a022afbc50357e2f871e426edc1db3 100644
--- a/tree/dataframe/test/dataframe_interface.cxx
+++ b/tree/dataframe/test/dataframe_interface.cxx
@@ -244,6 +244,43 @@ TEST(RDataFrameInterface, GetFilterNamesFromLoopManagerNoFilters)
    EXPECT_EQ(comparison, names);
 }
 
+TEST(RDataFrameInterface, GetDefinedColumnNamesFromScratch)
+{
+   RDataFrame f(1);
+   auto dummyGen = []() { return 1; };
+   auto names = f.Define("a", dummyGen).Define("b", dummyGen).Define("tdfDummy_", dummyGen).GetDefinedColumnNames();
+   ASSERT_EQ(2U, names.size()); // fatal on mismatch, so the indexing below stays in bounds
+   std::sort(names.begin(), names.end());
+   EXPECT_STREQ("a", names[0].c_str());
+   EXPECT_STREQ("b", names[1].c_str());
+}
+
+TEST(RDataFrameInterface, GetDefinedColumnNamesFromTree)
+{
+   TTree t("t", "t");
+   int a, b;
+   t.Branch("a", &a);
+   t.Branch("b", &b);
+   RDataFrame tdf(t);
+
+   auto dummyGen = []() { return 1; };
+   auto names = tdf.Define("d_a", dummyGen).Define("d_b", dummyGen).GetDefinedColumnNames();
+
+   ASSERT_EQ(2U, names.size()); // fatal on mismatch, so the indexing below stays in bounds
+   std::sort(names.begin(), names.end());
+   EXPECT_STREQ("d_a", names[0].c_str());
+   EXPECT_STREQ("d_b", names[1].c_str());
+}
+
+TEST(RDataFrameInterface, GetDefinedColumnNamesFromSource)
+{
+   std::unique_ptr<RDataSource> tds(new RTrivialDS(1));
+   RDataFrame tdf(std::move(tds));
+   auto names = tdf.Define("b", []() { return 1; }).GetDefinedColumnNames();
+   ASSERT_EQ(1U, names.size()); // fatal on mismatch, so the indexing below stays in bounds
+   EXPECT_STREQ("b", names[0].c_str());
+}
+
 TEST(RDataFrameInterface, DefaultColumns)
 {
    RDataFrame tdf(8);